from typing import List, Tuple, Optional, Union
import torch
from gpytorch.models import ApproximateGP
from gpytorch.mlls import VariationalELBO, ExactMarginalLogLikelihood
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.distributions import MultivariateNormal
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from botorch.acquisition import qNoisyExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.outcome import Standardize
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.model_list_gp_regression import ModelListGP
from gpytorch.kernels import ScaleKernel, MaternKernel
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms.input import Normalize
import gpytorch
import random
import numpy as np
from scipy.stats import dirichlet
from dataclasses import dataclass, fields

from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP  # Add TriObjectiveTSP
from MOCO.evaluation import MOCOEvaluator

import multiprocessing as mp
from functools import partial

import warnings
# Silence warnings process-wide as soon as this module is imported.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): this last call ignores every warning category, which makes the
# two category-specific filters above redundant and also hides deprecation and
# numerical warnings from all libraries — confirm this is intended.
warnings.filterwarnings("ignore")

@dataclass
class MOBOConfig:
    """Settings bundle for the MOBO-qParEGO optimizer (standard GP variant).

    Every field has a default, so ``MOBOConfig()`` is a valid configuration.
    Use :meth:`to_dict` / :meth:`from_dict` to round-trip through plain
    dictionaries (e.g. for JSON).
    """

    # --- Optimisation budget ---
    n_initial: int = 20
    n_iterations: int = 50
    q: int = 4  # Batch size

    # --- Genetic algorithm used to optimise the acquisition function ---
    pop_size: int = 100  # Individuals per generation
    n_generations: int = 30  # Generations per acquisition optimisation
    crossover_prob: float = 0.8
    mutation_prob: float = 0.2
    tournament_size: int = 3  # Contenders drawn per tournament

    # --- Scalarisation ---
    rho: float = 0.05  # Augmented Chebyshev parameter (weight of the sum term)

    # --- GP parameters ---
    matern_nu: float = 2.5  # Matern kernel smoothness parameter
    use_sparse_gp: bool = True  # Default to sparse GP
    model_rebuild_interval: int = 3  # Only rebuild GP model every N iterations

    def to_dict(self):
        """Return all fields as a ``{name: value}`` dict, in declaration order."""
        snapshot = {}
        for spec in fields(self):
            snapshot[spec.name] = getattr(self, spec.name)
        return snapshot

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a dictionary, silently dropping unknown keys.

        Keys that are missing from ``config_dict`` keep their default value.
        """
        known = {spec.name for spec in fields(cls)}
        accepted = {}
        for key, value in config_dict.items():
            if key in known:
                accepted[key] = value
        return cls(**accepted)

import torch
from gpytorch.kernels import Kernel

class KendallKernel(Kernel):
    """
    Kendall kernel for permutations.

    The similarity of two rows is their Kendall tau: the number of element
    pairs ordered the same way in both rows minus the number ordered
    oppositely, divided by the total number of pairs ``d * (d - 1) / 2``.

    Args:
        normalize (bool): If True, the tau values (in [-1, 1]) are mapped to
            [0, 1] via ``(tau + 1) / 2``. If False, raw tau values are returned.
    """
    has_lengthscale = False

    def __init__(self, normalize: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.normalize = normalize

    @staticmethod
    def _to_integer_rows(x):
        """Return ``x`` as-is if it is integer-valued, else its argsort along the last dim."""
        if not torch.all(x.int() == x):
            x = torch.argsort(x, dim=-1)
        return x

    def forward(self, x1, x2, diag=False, **params):
        """
        Args:
            x1, x2: Tensors whose last dimension has size d (shape ``[..., n, d]``).
                    Integer-valued rows are used directly; any other input is
                    replaced by its argsort along the last dimension.
            diag: If True, return only the diagonal, which is all ones because
                  every row is perfectly concordant with itself.
        Returns:
            Kernel matrix of shape ``[..., n1, n2]`` (or ``[..., n1]`` when
            ``diag`` is True) in a floating-point dtype.
        """
        # The result must be floating point even when the inputs are integer
        # permutations or become int64 through argsort; allocating the output
        # with the input dtype would truncate every tau to an integer.
        out_dtype = x1.dtype if torch.is_floating_point(x1) else torch.get_default_dtype()

        if diag:
            return torch.ones(x1.shape[:-1], dtype=out_dtype, device=x1.device)

        x1 = self._to_integer_rows(x1)
        x2 = self._to_integer_rows(x2)

        d = x1.shape[-1]
        if d < 2:
            # No pairs to compare: all rows are trivially identical.
            shape = tuple(x1.shape[:-1]) + tuple(x2.shape[-2:-1])
            return torch.ones(shape, dtype=out_dtype, device=x1.device)

        # Only the strictly-upper-triangular index pairs (a < b) are compared.
        # Including the diagonal would count d always-"concordant" entries and
        # push tau(x, x) above 1, contradicting the diag=True shortcut.
        rows, cols = torch.triu_indices(d, d, offset=1, device=x1.device)
        sign1 = torch.sign(x1[..., rows] - x1[..., cols]).to(out_dtype)
        sign2 = torch.sign(x2[..., rows] - x2[..., cols]).to(out_dtype)

        # A pair contributes +1 when both rows order it the same way and -1
        # otherwise, so the dot product of sign vectors is
        # (#concordant - #discordant).
        tau = sign1 @ sign2.transpose(-1, -2) / rows.numel()

        if self.normalize:
            return (tau + 1.0) / 2.0  # normalize to [0, 1]
        return tau


class MOBOqParEGO:
    """MOBO-qParEGO for Combinatorial Optimization with multi-objective support"""
    def __init__(self, 
             problem,
             config: Optional[Union[MOBOConfig, dict]] = None,
             **kwargs):
        """
        Initialize MOBOqParEGO optimizer
        
        Args:
            problem: Optimization problem instance
            config: Configuration object or dictionary
            **kwargs: Additional parameters to override config
        """
        # Merge config with any additional kwargs
        if isinstance(config, dict):
            config = MOBOConfig.from_dict(config)
        elif config is None:
            config = MOBOConfig()
        
        # Override config with any additional kwargs
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        
        self.problem = problem
        self.config = config
            
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.best_solutions = []
        self.X_train = None
        self.Y_train = None
        
        # FIXED: Detect number of objectives dynamically
        self.n_objectives = self._detect_num_objectives()
        print(f"Detected {self.n_objectives} objectives for problem {type(problem).__name__}")
        
        # Problem-specific attributes - more flexible approach
        self.is_tsp = hasattr(problem, 'n_cities')
        self.is_knapsack = hasattr(problem, 'n_items')
        self.solution_length = (
            problem.n_cities if self.is_tsp 
            else (problem.n_items if hasattr(problem, 'n_items')
                else problem.n_customers)  # Add MOCVRP case
        )
        
        # Model caching to avoid rebuilding every iteration
        self.cached_models = None
        self.iteration_counter = 0
        self.last_X_dim = None  # To detect when data changes significantly

        self.solution_counter = 0

        # Add these
        self.obj_min = None
        self.obj_max = None

        self.dtype = torch.float64
        self.total_evaluations = 0  # ← ADD THIS

    def _detect_num_objectives(self) -> int:
        """Detect number of objectives from the problem"""
        if hasattr(self.problem, 'num_objectives'):
            return self.problem.num_objectives
        elif hasattr(self.problem, 'n_objectives'):
            return self.problem.n_objectives
        else:
            # Try to infer from a dummy evaluation
            try:
                dummy_solution = self.problem.random_solution()
                dummy_objectives = self.problem.evaluate(dummy_solution)
                return len(dummy_objectives)
            except:
                # Default fallback
                return 2


    def _sample_weights(self, n_samples: int = 1) -> torch.Tensor:
        """Sample weights with better diversity"""
        weights_list = []
        
        # For 2D: Only add extreme weights if we have enough samples
        if self.n_objectives == 2:
            if n_samples >= 8: # was 6 before
                # Add 2 extreme weights
                for i in range(self.n_objectives):
                    weight_vector = torch.zeros(self.n_objectives, device=self.device, dtype=torch.float64)
                    weight_vector[i] = 0.95  # More extreme
                    weight_vector[1-i] = 0.05
                    weights_list.append(weight_vector)
                
                # Add balanced weight
                balanced_weight = torch.full((self.n_objectives,), 0.5, 
                                            device=self.device, dtype=torch.float64)
                weights_list.append(balanced_weight)
                
                # ✅ ADD EDGE WEIGHTS FOR 2D
                if n_samples > 3:
                    edge_weights_2d = [
                        [0.8, 0.2], [0.7, 0.3], [0.6, 0.4],  # Favor obj 1
                        [0.2, 0.8], [0.3, 0.7], [0.4, 0.6],  # Favor obj 2
                    ]
                    available_slots = min(len(edge_weights_2d), n_samples - len(weights_list))
                    for i in range(available_slots):
                        edge_weight = torch.tensor(edge_weights_2d[i], device=self.device, dtype=torch.float64)
                        weights_list.append(edge_weight)
                
                remaining = max(0, n_samples - len(weights_list))

                # Fill rest with Dirichlet
                # remaining = n_samples - 3
            else:
                # For small q, just use Dirichlet for diversity
                remaining = n_samples
        
        elif self.n_objectives == 3:
            #  existing 3D logic is fine
            # if n_samples < 20:
            #     n_samples = 20
            
            # Add extreme weights
            for i in range(self.n_objectives):
                weight_vector = torch.zeros(self.n_objectives, device=self.device, dtype=torch.float64)
                weight_vector[i] = 0.9
                remaining_weight = 0.1 / (self.n_objectives - 1)
                for j in range(self.n_objectives):
                    if j != i:
                        weight_vector[j] = remaining_weight
                weights_list.append(weight_vector)
            
            # Add balanced
            if n_samples > self.n_objectives:
                balanced_weight = torch.full((self.n_objectives,), 1.0/self.n_objectives, 
                                            device=self.device, dtype=torch.float64)
                weights_list.append(balanced_weight)
            
            # Add edge weights for 3D
            if n_samples > self.n_objectives + 1:
                edge_weights = [
                    [0.7, 0.25, 0.05], [0.7, 0.05, 0.25],
                    [0.25, 0.7, 0.05], [0.05, 0.7, 0.25],
                    [0.25, 0.05, 0.7], [0.05, 0.25, 0.7],
                ]
                available_slots = min(len(edge_weights), n_samples - len(weights_list))
                for i in range(available_slots):
                    edge_weight = torch.tensor(edge_weights[i], device=self.device, dtype=torch.float64)
                    weights_list.append(edge_weight)
            
            remaining = max(0, n_samples - len(weights_list))
        
        else:  # 4+ objectives
            # Use existing logic from corrected HW-SW version
            remaining = n_samples


        # Fill remaining with Dirichlet samples
        if remaining > 0:
            # Use LESS concentrated alpha for MORE diversity
            alpha = np.ones(self.n_objectives) * (1.0 if self.n_objectives >= 3 else 0.5)
            # if self.n_objectives == 3:
            #     alpha = np.ones(self.n_objectives) * 1.0  # More uniform for 3D
            # else:
            #     alpha = np.ones(self.n_objectives) * 0.5  # More uniform for 2D (was 0.3)
            
            dirichlet_samples = dirichlet.rvs(alpha, size=remaining)
            
            if remaining == 1:
                dirichlet_samples = dirichlet_samples.reshape(1, -1)
            
            for i in range(remaining):
                weight = torch.from_numpy(dirichlet_samples[i]).to(
                    dtype=torch.float64, 
                    device=self.device
                )
                weights_list.append(weight)
        
        return torch.stack(weights_list)

    
    # def _scalarize_objectives(self, Y, weights):
    #     """Scalarize multiple objectives using weight vector - FIXED for multi-objective"""
    #     # Ensure weights is properly shaped
    #     w = weights.view(1, -1)  # Shape: (1, n_objectives)
        
    #     # Compute weighted sum
    #     weighted = Y * w
        
    #     # Compute max and sum terms for augmented Chebyshev
    #     max_term = weighted.max(dim=1, keepdim=True).values
    #     sum_term = self.config.rho * weighted.sum(dim=1, keepdim=True)
        
    #     # Return scalarized objectives
    #     return max_term + sum_term
    
    def _scalarize_objectives(self, Y, weights):
        """
        FIXED: Proper normalization with direction awareness
        
        For TSP: All objectives are minimization (0 = best)
        For BiKP: Override this method (maximization)
        """
        # Add normalization
        self._update_objective_bounds(Y)
        if self.obj_min is not None:
            range_vals = self.obj_max - self.obj_min + 1e-8
            Y_norm = (Y - self.obj_min) / range_vals
        else:
            Y_norm = Y
        
        w = weights.view(1, -1)
        weighted = Y_norm * w
        max_term = weighted.max(dim=1, keepdim=True).values
        sum_term = self.config.rho * weighted.sum(dim=1, keepdim=True)
        return max_term + sum_term

    def _update_objective_bounds(self, objectives: torch.Tensor):
        if self.obj_min is None:
            self.obj_min = objectives.min(dim=0).values
            self.obj_max = objectives.max(dim=0).values
        else:
            self.obj_min = torch.min(self.obj_min, objectives.min(dim=0).values)
            self.obj_max = torch.max(self.obj_max, objectives.max(dim=0).values)

    def _create_initial_solution(self) -> torch.Tensor:
        if self.is_tsp:
            return torch.randperm(self.solution_length)
        else:
            # For knapsack, use progressive diversity
            sol = torch.zeros(self.solution_length, dtype=torch.float64, device=self.device)
            
            # Determine selection probability based on solution counter
            total_initial = self.config.n_initial
            
            if self.solution_counter < total_initial * 0.8:
                # Progressive selection probability for diversity
                selection_prob = 0.1 + (0.7 * self.solution_counter / (total_initial * 0.8))
            else:
                # Random selection probability for exploration
                selection_prob = random.random()
            
            # Track capacity
            remaining_capacity = self.problem.capacity
            
            # Random order to avoid bias
            items = list(range(self.problem.n_items))
            random.shuffle(items)
            
            # Select items based on probability and capacity
            for item_idx in items:
                if remaining_capacity >= self.problem.weights[item_idx]:
                    if random.random() < selection_prob:
                        sol[item_idx] = 1
                        remaining_capacity -= self.problem.weights[item_idx]
            
            # Increment counter for next solution
            self.solution_counter += 1
            
            return sol

    def _tournament_select(self, 
                         population: List[torch.Tensor],
                         fitness_values: List[float]) -> torch.Tensor:
        """Tournament selection"""
        indices = random.sample(range(len(population)), 
                              self.config.tournament_size)
        winner_idx = max(indices, key=lambda i: fitness_values[i])
        return population[winner_idx]

    def _order_crossover(self, 
                    parent1: torch.Tensor,
                    parent2: torch.Tensor) -> torch.Tensor:
        """Order Crossover (OX) for TSP"""
        size = len(parent1)
        # Get random slice positions
        start, end = sorted(random.sample(range(size), 2))
        
        # Create child with -1s
        child = torch.full_like(parent1, -1)
        
        # Step 1: Copy slice from parent1
        child[start:end] = parent1[start:end]
        
        # Step 2: Create ordered list of remaining cities from parent2
        remaining = []
        seen = set(parent1[start:end].tolist())
        
        # First add elements from end to size
        for i in range(end, size):
            elem = parent2[i].item()
            if elem not in seen:
                remaining.append(elem)
                seen.add(elem)
                
        # Then add elements from 0 to end
        for i in range(0, end):
            elem = parent2[i].item()
            if elem not in seen:
                remaining.append(elem)
                seen.add(elem)
                
        # Step 3: Fill remaining positions in child
        remaining_idx = 0
        
        # Fill tail positions first
        for i in range(end, size):
            if remaining_idx < len(remaining):
                child[i] = remaining[remaining_idx]
                remaining_idx += 1
                
        # Then fill head positions
        for i in range(0, start):
            if remaining_idx < len(remaining):
                child[i] = remaining[remaining_idx]
                remaining_idx += 1
        
        return child

    def _uniform_crossover(self,
                          parent1: torch.Tensor,
                          parent2: torch.Tensor) -> torch.Tensor:
        """Uniform crossover for Knapsack"""
        mask = torch.rand_like(parent1) < 0.5
        child = torch.where(mask, parent1, parent2)
        return child

    def _swap_mutation(self, solution: torch.Tensor) -> torch.Tensor:
        """Swap mutation for TSP"""
        mutated = solution.clone()
        i, j = random.sample(range(len(solution)), 2)
        mutated[i], mutated[j] = mutated[j], mutated[i]
        return mutated

    def _bit_flip_mutation(self, solution: torch.Tensor) -> torch.Tensor:
        """Bit flip mutation for Knapsack"""
        mutated = solution.clone()
        idx = random.randrange(len(solution))
        mutated[idx] = 1 - mutated[idx]
        return mutated

    def encode_solution(self, solution: torch.Tensor) -> torch.Tensor:
        if self.is_tsp:
            # Existing TSP encoding
            encoded = torch.zeros(len(solution), dtype=torch.float64, device=self.device)
            for i, city in enumerate(solution):
                encoded[city] = float(i) / float(len(solution))
                # encoded[i] = float(city) / float(len(solution))
        else:
            # For knapsack, just use binary encoding - NO penalty feature
            encoded = solution.to(dtype=torch.float64, device=self.device)
        return encoded
    
    # def encode_solution(self, solution: torch.Tensor) -> torch.Tensor:
    #     if self.is_tsp:
    #         n = len(solution)
    #         # position-of-city encoding: for each city id, store its position / n
    #         encoded = torch.zeros(n, dtype=torch.float32, device=self.device)
    #         for pos, city in enumerate(solution.tolist()):
    #             encoded[int(city)] = float(pos) / float(n)
    #         return encoded
    #     else:
    #         return solution.to(dtype=torch.float32, device=self.device)


    # def _create_scalarization(self, weights: torch.Tensor) -> GenericMCObjective:
    #     """Create scalarization function with augmented Chebyshev - FIXED for multi-objective"""
    #     def scalarize(samples, X=None):
    #         """
    #         Args:
    #             samples: Tensor of shape (n_samples, batch_size, n_objectives)
    #             X: Tensor of input locations (not used in scalarization)
    #         Returns:
    #             Tensor of shape (n_samples, batch_size)
    #         """
    #         # Ensure weights match the sample dimensions
    #         w = weights.view(1, 1, -1)  # Shape: (1, 1, n_objectives)
    #         # Compute weighted objectives
    #         weighted = samples * w  # Broadcasting to match samples shape
            
    #         # Compute max and sum terms
    #         max_term = weighted.max(dim=-1).values  # Shape: (n_samples, batch_size)
    #         sum_term = self.config.rho * weighted.sum(dim=-1)  # Shape: (n_samples, batch_size)
            
    #         return max_term + sum_term
        
    #     return GenericMCObjective(scalarize)

    def _create_scalarization(self, weights: torch.Tensor):
        """Create scalarization with proper min/max handling"""
        def scalarize(samples, X=None):
            if self.obj_min is not None:
                obj_min = self.obj_min.view(1, 1, -1)
                obj_max = self.obj_max.view(1, 1, -1)
                range_vals = obj_max - obj_min + 1e-8
                samples_norm = (samples - obj_min) / range_vals
                
                # ✅ CRITICAL FIX: Invert for minimization
                # if self.is_tsp:
                #     samples_norm = 1.0 - samples_norm
            else:
                samples_norm = samples
            
            w = weights.view(1, 1, -1)
            weighted = samples_norm * w
            max_term = weighted.max(dim=-1).values
            sum_term = self.config.rho * weighted.sum(dim=-1)
            return max_term + sum_term
        
        return GenericMCObjective(scalarize)

    # def optimize_acquisition(self, 
    #                    acq_function,
    #                    weights: torch.Tensor) -> torch.Tensor:
    #     """Optimized acquisition function optimization"""
    #     population = []
    #     fitness_values = []
        
    #     # Add existing Pareto-optimal solutions - but limit to just a few
    #     # if self.best_solutions:
    #     #     for sol, _ in self.best_solutions[:5]:  # Only use top 5
    #     #         solution = torch.tensor(sol, device=self.device)
    #     #         population.append(solution)
        
    #     # Generate remaining solutions randomly
    #     while len(population) < self.config.pop_size:
    #         solution = self._create_initial_solution()
    #         population.append(solution)
        
    #     # Pre-encode solutions once
    #     encoded_population = [self.encode_solution(sol) for sol in population]
        
    #     # Batch evaluate initial population - much faster
    #     with torch.no_grad():
    #         # Prepare batch
    #         batch_encoded = torch.stack([enc.unsqueeze(0) for enc in encoded_population])
            
    #         try:
    #             # Batch evaluation
    #             acq_values = acq_function(batch_encoded)
    #             # Handle NaN/Inf
    #             acq_values = torch.where(
    #                 torch.isnan(acq_values) | torch.isinf(acq_values),
    #                 torch.tensor(-1e10, device=self.device, dtype=torch.float64),
    #                 acq_values
    #             )
    #         except Exception:
    #             # Fallback to individual evaluation if batch fails
    #             acq_values = []
    #             for encoded in encoded_population:
    #                 try:
    #                     val = acq_function(encoded.unsqueeze(0).unsqueeze(1))
    #                     if torch.isnan(val) or torch.isinf(val):
    #                         val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
    #                 except Exception:
    #                     val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
    #                 acq_values.append(val)
    #             acq_values = torch.tensor(acq_values, device=self.device)
            
    #         fitness_values = acq_values.tolist()
        
    #     # Early termination if found good solution
    #     best_fitness_idx = np.argmax(fitness_values)
    #     if fitness_values[best_fitness_idx] > 0.2:  # 0.1 Threshold for "good enough"
    #         return population[best_fitness_idx]
        
    #     # Fewer generations
    #     generations_to_run = min(self.config.n_generations, 30) # 15
        
    #     # GA generations - optimized
    #     for gen in range(generations_to_run):
    #         offspring = []
    #         offspring_encoded = []
            
    #         # Generate offspring more efficiently
    #         while len(offspring) < self.config.pop_size:
    #             # Tournament selection
    #             parent1 = self._tournament_select(population, fitness_values)
    #             parent2 = self._tournament_select(population, fitness_values)
                
    #             # Crossover
    #             if random.random() < self.config.crossover_prob:
    #                 child = (self._order_crossover(parent1, parent2) if self.is_tsp
    #                         else self._uniform_crossover(parent1, parent2))
    #             else:
    #                 child = parent1.clone()
                
    #             # Mutation
    #             if random.random() < self.config.mutation_prob:
    #                 child = (self._swap_mutation(child) if self.is_tsp
    #                         else self._bit_flip_mutation(child))
                
    #             offspring.append(child)
    #             offspring_encoded.append(self.encode_solution(child))
            
    #         # Batch evaluate offspring - much faster
    #         with torch.no_grad():
    #             # Prepare batch
    #             batch_encoded = torch.stack([enc.unsqueeze(0) for enc in offspring_encoded])
                
    #             try:
    #                 # Batch evaluation
    #                 acq_values = acq_function(batch_encoded)
    #                 # Handle NaN/Inf
    #                 acq_values = torch.where(
    #                     torch.isnan(acq_values) | torch.isinf(acq_values),
    #                     torch.tensor(-1e10, device=self.device, dtype=torch.float64),
    #                     acq_values
    #                 )
    #             except Exception:
    #                 # Fallback to individual evaluation if batch fails
    #                 acq_values = []
    #                 for encoded in offspring_encoded:
    #                     try:
    #                         val = acq_function(encoded.unsqueeze(0).unsqueeze(1))
    #                         if torch.isnan(val) or torch.isinf(val):
    #                             val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
    #                     except Exception:
    #                         val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
    #                     acq_values.append(val)
    #                 acq_values = torch.tensor(acq_values, device=self.device)
                
    #             offspring_fitness = acq_values.tolist()
            
    #         # Early termination if found good solution
    #         best_offspring_idx = np.argmax(offspring_fitness)
    #         if offspring_fitness[best_offspring_idx] > 0.2:  # Higher threshold for early termination
    #             return offspring[best_offspring_idx]
            
    #         # Quick selection for next generation - just take the best
    #         combined_pop = population + offspring
    #         combined_fit = fitness_values + offspring_fitness
    #         combined_enc = encoded_population + offspring_encoded
            
    #         # Sort by fitness (descending)
    #         sorted_indices = np.argsort(combined_fit)[::-1]
            
    #         # Select top individuals
    #         population = [combined_pop[i] for i in sorted_indices[:self.config.pop_size]]
    #         fitness_values = [combined_fit[i] for i in sorted_indices[:self.config.pop_size]]
    #         encoded_population = [combined_enc[i] for i in sorted_indices[:self.config.pop_size]]
        
    #     # Return best solution
    #     best_idx = np.argmax(fitness_values)
    #     return population[best_idx]

    def optimize_acquisition(self, acq_function, weights: torch.Tensor) -> torch.Tensor:
        population = []
        fitness_values = []
        
        # Generate initial population - NO seeding
        while len(population) < self.config.pop_size:
            solution = self._create_initial_solution()
            population.append(solution)
        
        # Pre-encode solutions
        encoded_population = [self.encode_solution(sol) for sol in population]
        
        # Batch evaluate initial population
        with torch.no_grad():
            batch_encoded = torch.stack([enc.unsqueeze(0) for enc in encoded_population])
            try:
                acq_values = acq_function(batch_encoded)
                acq_values = torch.where(
                    torch.isnan(acq_values) | torch.isinf(acq_values),
                    torch.tensor(-1e10, device=self.device, dtype=torch.float64),
                    acq_values
                )
            except Exception:
                acq_values = []
                for encoded in encoded_population:
                    try:
                        val = acq_function(encoded.unsqueeze(0).unsqueeze(1))
                        if torch.isnan(val) or torch.isinf(val):
                            val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
                    except Exception:
                        val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
                    acq_values.append(val)
                acq_values = torch.tensor(acq_values, device=self.device)
            
            fitness_values = acq_values.tolist()
        
        # REMOVE: Early termination check
        
        # Run full GA - NO shortcuts
        for gen in range(self.config.n_generations):
            offspring = []
            offspring_encoded = []
            
            # Generate offspring
            while len(offspring) < self.config.pop_size:
                parent1 = self._tournament_select(population, fitness_values)
                parent2 = self._tournament_select(population, fitness_values)
                
                if random.random() < self.config.crossover_prob:
                    child = (self._order_crossover(parent1, parent2) if self.is_tsp
                            else self._uniform_crossover(parent1, parent2))
                else:
                    child = parent1.clone()
                
                if random.random() < self.config.mutation_prob:
                    child = (self._swap_mutation(child) if self.is_tsp
                            else self._bit_flip_mutation(child))
                
                offspring.append(child)
                offspring_encoded.append(self.encode_solution(child))
            
            # Batch evaluate offspring
            with torch.no_grad():
                batch_encoded = torch.stack([enc.unsqueeze(0) for enc in offspring_encoded])
                try:
                    acq_values = acq_function(batch_encoded)
                    acq_values = torch.where(
                        torch.isnan(acq_values) | torch.isinf(acq_values),
                        torch.tensor(-1e10, device=self.device, dtype=torch.float64),
                        acq_values
                    )
                except Exception:
                    acq_values = []
                    for encoded in offspring_encoded:
                        try:
                            val = acq_function(encoded.unsqueeze(0).unsqueeze(1))
                            if torch.isnan(val) or torch.isinf(val):
                                val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
                        except Exception:
                            val = torch.tensor(-1e10, device=self.device, dtype=torch.float64)
                        acq_values.append(val)
                    acq_values = torch.tensor(acq_values, device=self.device)
                
                offspring_fitness = acq_values.tolist()
            
            # REMOVE: Early termination in offspring
            
            # Selection - keep best
            combined_pop = population + offspring
            combined_fit = fitness_values + offspring_fitness
            combined_enc = encoded_population + offspring_encoded
            
            sorted_indices = np.argsort(combined_fit)[::-1]
            
            population = [combined_pop[i] for i in sorted_indices[:self.config.pop_size]]
            fitness_values = [combined_fit[i] for i in sorted_indices[:self.config.pop_size]]
            encoded_population = [combined_enc[i] for i in sorted_indices[:self.config.pop_size]]
        
        # Return best solution
        best_idx = np.argmax(fitness_values)
        return population[best_idx]


    # def build_model_list(self):
    #     """Build a ModelList combining single-objective models - FIXED for multi-objective"""
    #     # Update iteration counter
    #     self.iteration_counter += 1
        
    #     # Check if we should use cached models
    #     if (self.cached_models is not None and 
    #         self.iteration_counter % self.config.model_rebuild_interval != 1 and
    #         self.X_train.shape[0] == self.last_X_dim):
    #         print("  Using cached GP models")
    #         return self.cached_models
        
    #     print("  Building new GP models")
        
    #     # Ensure all inputs are on the same device
    #     X_train_device = self.X_train.to(self.device, dtype=torch.float64)
    #     Y_train_device = self.Y_train.to(self.device, dtype=torch.float64)
        
    #     # Create a list to hold the individual models
    #     models = []
        
    #     # FIXED: Use actual number of objectives from Y_train
    #     n_objectives = self.Y_train.shape[1]
        
    #     if self.config.use_sparse_gp:
    #         # Build a sparse GP model for each objective
    #         for i in range(n_objectives):
    #             # Extract this objective's values
    #             Y_obj = Y_train_device[:, i:i+1]
                
    #             # Create sparse GP with Standardize transform
    #             model = SingleTaskVariationalGP(
    #                 X_train_device, 
    #                 Y_obj,
    #                 outcome_transform=Standardize(m=1)
    #             )
    #             model = model.to(self.device)
                
    #             # Fit the model
    #             mll = VariationalELBO(model.likelihood, model.model, num_data=Y_obj.size(0))
    #             from botorch.fit import fit_gpytorch_mll
    #             fit_gpytorch_mll(mll)
                
    #             models.append(model)
    #     else:
    #         # Use standard GP
    #         for i in range(n_objectives):
    #             # Extract this objective's values
    #             Y_obj = Y_train_device[:, i:i+1]
                
    #             # Create standard GP with Standardize transform
    #             model = SingleTaskGP(
    #                 X_train_device, 
    #                 Y_obj,
    #                 outcome_transform=Standardize(m=1)
    #             )
    #             model = model.to(self.device)
                
    #             # Replace the default kernel with a Matern kernel
    #             model.covar_module = ScaleKernel(
    #                 MaternKernel(nu=self.config.matern_nu, ard_num_dims=X_train_device.shape[1]),
    #                 outputscale_prior=gpytorch.priors.GammaPrior(2.0, 0.15)
    #             ).to(self.device)
                
    #             # Fit the model
    #             from botorch.fit import fit_gpytorch_mll
    #             from gpytorch.mlls import ExactMarginalLogLikelihood
    #             mll = ExactMarginalLogLikelihood(model.likelihood, model)
    #             fit_gpytorch_mll(mll, options={
    #                     'maxiter': 80,     # Instead of default 100-200
    #                     'ftol': 1e-5,      # Less strict convergence  
    #                     'lr': 0.1,         # Higher learning rate)
    #                     })
    #             models.append(model)

    #     # Create a ModelListGP with the individual models
    #     model_list = ModelListGP(*models)

    #     # Cache the model list and input dimension
    #     self.cached_models = model_list
    #     self.last_X_dim = self.X_train.shape[0]

    #     return model_list

    # def build_model_list(self):
    #     """Build a ModelList with numerical stability improvements"""
    #     self.iteration_counter += 1
        
    #     # CRITICAL: Limit training data size
    #     MAX_TRAIN_SIZE = 100  # Reduced from unlimited
    #     if self.X_train.shape[0] > MAX_TRAIN_SIZE:
    #         # Keep most recent points (they're most relevant)
    #         self.X_train = self.X_train[-MAX_TRAIN_SIZE:]
    #         self.Y_train = self.Y_train[-MAX_TRAIN_SIZE:]
    #         print(f"  Reduced training data to {MAX_TRAIN_SIZE} points")
        
    #     # Remove duplicates and near-duplicates
    #     self.X_train, self.Y_train = self._remove_duplicates(self.X_train, self.Y_train)
        
    #     print(f"  Building new GP models with {self.X_train.shape[0]} points")
        
    #     X_train_device = self.X_train.to(self.device, dtype=torch.float64)
    #     Y_train_device = self.Y_train.to(self.device, dtype=torch.float64)
        
    #     models = []
    #     n_objectives = self.Y_train.shape[1]
        
    #     for i in range(n_objectives):
    #         Y_obj = Y_train_device[:, i:i+1]
            
    #         # Use simpler model for stability
    #         if self.X_train.shape[0] > 50:
    #             # Switch to sparse GP for larger datasets
    #             from botorch.models import SingleTaskVariationalGP
    #             model = SingleTaskVariationalGP(
    #                 X_train_device, 
    #                 Y_obj,
    #                 outcome_transform=Standardize(m=1),
    #                 inducing_points=X_train_device[:min(25, len(X_train_device))]
    #             )
    #         else:
    #             # Standard GP for small datasets
    #             model = SingleTaskGP(
    #                 X_train_device, 
    #                 Y_obj,
    #                 outcome_transform=Standardize(m=1)
    #             )
            
    #         model = model.to(self.device)
            
    #         # Robust fitting with fallback
    #         fitted = False
    #         for attempt in range(3):
    #             try:
    #                 if hasattr(model, 'likelihood'):
    #                     from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
                        
    #                     if isinstance(model, SingleTaskVariationalGP):
    #                         mll = VariationalELBO(model.likelihood, model.model, 
    #                                             num_data=Y_obj.shape[0])
    #                     else:
    #                         mll = ExactMarginalLogLikelihood(model.likelihood, model)
                        
    #                     # Progressive relaxation of fitting parameters
    #                     options_list = [
    #                         {'maxiter': 20, 'lr': 0.01},  # Quick attempt
    #                         {'maxiter': 10, 'lr': 0.005},  # Even quicker
    #                         {'maxiter': 5}  # Minimal fitting
    #                     ]
                        
    #                     from botorch.fit import fit_gpytorch_mll
    #                     fit_gpytorch_mll(mll, options=options_list[attempt])
    #                     fitted = True
    #                     break
    #             except Exception as e:
    #                 if attempt == 2:
    #                     print(f"  Warning: Could not fit model for objective {i}, using default")
    #                     # Don't fail - use the unfitted model with default parameters
                        
    #         models.append(model)
        
    #     return ModelListGP(*models)

    def build_model_list(self):
        """Fit one GP surrogate per objective and return them as a ModelListGP.

        Side effects on the instance: increments ``self.iteration_counter``,
        trims ``self.X_train`` / ``self.Y_train`` to their last 100 rows,
        de-duplicates them via ``self._remove_duplicates`` and stores the
        result in ``self.cached_models`` / ``self.last_X_dim``.

        Model per objective column:
        - more than 50 training points: ``SingleTaskVariationalGP`` with
          ``min(25, max(5, n_points // 4))`` randomly chosen inducing points
        - otherwise: exact ``SingleTaskGP``
        Both use a ``Normalize`` input transform and a ``Standardize``
        outcome transform.

        Kernel (wrapped in a ``ScaleKernel`` with a Gamma(2.0, 0.15)
        outputscale prior):
        - TSP: KendallKernel (permutations)
        - Knapsack: ARD MaternKernel (binary vectors)

        The kernel is handed to the model constructor. Assigning
        ``covar_module`` on a ``SingleTaskVariationalGP`` after construction
        only adds an attribute to the wrapper; the inner GP at ``model.model``
        (the module passed to ``VariationalELBO``) would keep its default
        kernel.

        Returns:
            ModelListGP with one model per column of ``self.Y_train``. A model
            whose fit failed on every attempt is still included, unfitted.

        Raises:
            RuntimeError: if no training points remain after de-duplication.
        """
        self.iteration_counter += 1

        # Safety: small training set cap to keep fitting stable
        MAX_TRAIN_SIZE = 100
        if self.X_train.shape[0] > MAX_TRAIN_SIZE:
            # Keep only the most recently appended rows
            self.X_train = self.X_train[-MAX_TRAIN_SIZE:]
            self.Y_train = self.Y_train[-MAX_TRAIN_SIZE:]
            print(f"  Reduced training data to {MAX_TRAIN_SIZE} points")

        # Remove duplicates
        self.X_train, self.Y_train = self._remove_duplicates(self.X_train, self.Y_train)

        n_points = self.X_train.shape[0]
        if n_points == 0:
            raise RuntimeError("No training points available to build models")

        print(f"  Building new GP models with {n_points} points (dtype={self.dtype}, device={self.device})")

        # Move data to device and dtype (use self.dtype)
        X_train_device = self.X_train.to(device=self.device, dtype=self.dtype)
        Y_train_device = self.Y_train.to(device=self.device, dtype=self.dtype)

        n_dims = X_train_device.shape[1]
        n_objectives = Y_train_device.shape[1]
        # Sparse (variational) GP above 50 points, exact GP otherwise
        use_variational = n_points > 50

        # Fit settings tried in order until one attempt succeeds
        fit_attempts = [
            {'maxiter': 80, 'lr': 0.1},
            {'maxiter': 40, 'lr': 0.05},
            {'maxiter': 10, 'lr': 0.01},
        ]

        models = []
        for i in range(n_objectives):
            Y_obj = Y_train_device[:, i:i+1]

            # Build the problem-specific kernel first so it can be passed to
            # the constructor; on failure fall back to the model's default.
            covar_module = None
            try:
                if self.is_tsp:
                    base_kernel = KendallKernel(normalize=True)
                    print(f"    Using KendallKernel for TSP objective {i}")
                else:
                    # Knapsack: ARD Matern kernel, one lengthscale per input dimension
                    base_kernel = MaternKernel(nu=self.config.matern_nu, ard_num_dims=n_dims)
                    print(f"    Using Matern ARD kernel for Knapsack objective {i}")

                covar_module = ScaleKernel(
                    base_kernel,
                    outputscale_prior=gpytorch.priors.GammaPrior(2.0, 0.15)
                )
            except Exception as e:
                print("  Warning: could not build custom kernel, using model default:", e)

            # Build model with input & outcome transforms
            if use_variational:
                # Random subset of the training inputs as inducing points
                n_inducing = min(25, max(5, n_points // 4))
                inducing_idx = torch.randperm(n_points)[:n_inducing]
                inducing_points = X_train_device[inducing_idx].clone()
                model = SingleTaskVariationalGP(
                    X_train_device,
                    Y_obj,
                    covar_module=covar_module,
                    input_transform=Normalize(d=n_dims),
                    outcome_transform=Standardize(m=1),
                    inducing_points=inducing_points
                )
            else:
                model = SingleTaskGP(
                    X_train_device,
                    Y_obj,
                    covar_module=covar_module,
                    input_transform=Normalize(d=n_dims),
                    outcome_transform=Standardize(m=1)
                )

            # Move model to the target device, then try to align its dtype
            model.to(device=self.device)
            try:
                model = model.to(dtype=self.dtype)
            except Exception as e:
                # Best effort: keep the model in its current dtype
                print(f"  Warning: could not cast model for objective {i} to {self.dtype}: {e}")

            # Module that actually owns the kernel used for inference
            kernel_owner = model.model if isinstance(model, SingleTaskVariationalGP) else model

            # Robust fitting with progressive relaxation
            fitted = False
            for attempt_idx, opts in enumerate(fit_attempts):
                try:
                    if isinstance(model, SingleTaskVariationalGP):
                        mll = VariationalELBO(model.likelihood, model.model, num_data=Y_obj.shape[0])
                    else:
                        mll = ExactMarginalLogLikelihood(model.likelihood, model)

                    # NOTE(review): recent BoTorch releases appear to take optimizer
                    # settings via `optimizer_kwargs`; confirm `options` is honoured
                    # by the installed version.
                    fit_gpytorch_mll(mll, options=opts)
                    fitted = True
                    break
                except Exception as e:
                    print(f"  Fit attempt {attempt_idx+1} for objective {i} failed: {e}")
                    # Perturb the kernel hyperparameters before the next attempt
                    try:
                        kernel = getattr(kernel_owner, "covar_module", None)
                        if kernel is not None:
                            for p in kernel.parameters():
                                if torch.isfinite(p).all():
                                    p.data = p.data * 0.5
                    except Exception:
                        # Re-initialisation is optional; the next attempt runs regardless
                        pass

            if not fitted:
                print(f"  Warning: Could not fit model for objective {i}; using unfitted model (predictions may be poor)")

            models.append(model)

        model_list = ModelListGP(*models)

        # Optional: cache the models
        self.cached_models = model_list
        self.last_X_dim = self.X_train.shape[0]

        return model_list


    def _remove_duplicates(self, X, Y, tol=1e-4):
        """Drop rows of X (and the matching rows of Y) that repeat an earlier row.

        Rows are scanned in order. A row is kept unless its Euclidean distance
        to some previously kept row is below ``tol``; the first occurrence of
        each cluster therefore survives.

        Args:
            X: tensor whose first dimension indexes points.
            Y: tensor indexed along its first dimension in step with ``X``.
            tol: distance below which two rows count as duplicates.

        Returns:
            Tuple ``(X_kept, Y_kept)`` holding the surviving rows in their
            original order.
        """
        survivors = []
        for row in range(X.shape[0]):
            # Compare only against rows already accepted (greedy filter)
            near_existing = any(
                torch.norm(X[row] - X[kept]) < tol for kept in survivors
            )
            if near_existing:
                continue
            survivors.append(row)

        n_dropped = X.shape[0] - len(survivors)
        if n_dropped > 0:
            print(f"  Removed {n_dropped} duplicate points")

        return X[survivors], Y[survivors]

    def run(self) -> List[Tuple[List[int], List[float]]]:
        """Run MOBO-qParEGO optimization with batch acquisition (q points per iteration).

        Phase 1 (initial design): repeatedly create a solution with
        ``self._create_initial_solution``, evaluate it, and keep it when all
        objectives are finite, until ``config.n_initial`` solutions are
        collected or ``5 * config.n_initial`` attempts have been made. The
        kept points become ``self.X_train`` / ``self.Y_train``.

        Phase 2 (``config.n_iterations`` rounds): rebuild the surrogate models
        once, sample ``config.q`` weight vectors, and for each one optimise a
        scalarised ``qNoisyExpectedImprovement`` to propose a solution. Valid
        proposals are evaluated and appended to the training data together at
        the end of the round.

        Every call to ``self.problem.evaluate`` increments
        ``self.total_evaluations``, whether or not the result is kept.

        Returns:
            One ``(solution, objectives)`` pair per entry of
            ``self.best_solutions``, objectives converted to a plain list.

        Raises:
            ValueError: if the initial design produced no valid solution. The
                last exception swallowed while generating candidates (if any)
                is attached as ``__cause__``.
        """
        def _all_finite(values) -> bool:
            # True when no objective value is NaN or +/-Inf
            return all(np.isfinite(v) for v in values)

        # Initial design
        X_init = []
        Y_init = []

        print("Generating initial solutions...")
        attempts = 0
        max_attempts = self.config.n_initial * 5  # Fewer attempts
        last_error = None  # most recent failure, reported if nothing succeeds

        while len(X_init) < self.config.n_initial and attempts < max_attempts:
            attempts += 1
            try:
                # Generate and evaluate an initial solution
                solution = self._create_initial_solution()
                objectives = self.problem.evaluate(solution.tolist())

                self.total_evaluations += 1  # counted even if the result is rejected

                # Skip solutions with NaN/Inf objectives
                if not _all_finite(objectives):
                    continue

                encoded = self.encode_solution(solution)

                # One row of objectives, shape (1, n_objectives)
                obj_tensor = torch.tensor(
                    objectives,
                    dtype=torch.float64,
                    device=self.device
                ).view(1, -1)

                X_init.append(encoded)
                Y_init.append(obj_tensor)
                self._add_if_nondominated(solution.tolist(), list(objectives))
            except Exception as exc:
                # Best effort: one bad candidate must not abort the design,
                # but remember why it failed for the final error report.
                last_error = exc

        # Check if we have any valid initial solutions
        if not X_init:
            raise ValueError("Could not generate any valid initial solutions") from last_error

        # Convert to tensors
        self.X_train = torch.stack(X_init)
        self.Y_train = torch.cat(Y_init, dim=0)

        print(f"Initial design complete: {len(X_init)} solutions, {self.Y_train.shape[1]} objectives")

        # Main optimization loop
        for iteration in range(self.config.n_iterations):
            print(f"Iteration {iteration+1}/{self.config.n_iterations}")

            # Build model list ONCE per iteration
            model_list = self.build_model_list()

            # For qParEGO, generate q different weight vectors per iteration
            weights_batch = self._sample_weights(n_samples=self.config.q)
            print(f"  Using {len(weights_batch)} weight vectors for this iteration")

            # New points collected during this iteration
            batch_objectives = []
            batch_encoded = []

            # For each weight vector in the batch (q different scalarizations)
            for batch_idx, weights in enumerate(weights_batch):
                try:
                    # Skip if not enough training data
                    if len(self.X_train) < 2:
                        break

                    print(f"  Processing weight vector {batch_idx+1}/{len(weights_batch)}: {weights.tolist()}")

                    # Scalarised objective for this weight vector
                    objective = self._create_scalarization(weights)

                    # One acquisition function per scalarization
                    acq_function = qNoisyExpectedImprovement(
                        model=model_list,
                        objective=objective,
                        X_baseline=self.X_train,
                        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([64]))
                    )

                    # Optimize acquisition function
                    solution = self.optimize_acquisition(acq_function, weights)

                    # Evaluate solution
                    objectives = self.problem.evaluate(solution.tolist())

                    self.total_evaluations += 1  # counted even if the result is rejected

                    # Skip solutions with NaN/Inf objectives
                    if not _all_finite(objectives):
                        continue

                    # Add to batch
                    batch_objectives.append(objectives)
                    batch_encoded.append(self.encode_solution(solution))

                    # Update Pareto front right away
                    self._add_if_nondominated(solution.tolist(), list(objectives))

                except Exception as e:
                    # Best effort: a failed scalarization does not stop the batch
                    print(f"Error in optimization iteration: {e}")

            # After evaluating all q points, update the model with the batch
            if batch_encoded:
                # Convert batch to tensors
                batch_encoded_tensor = torch.stack(batch_encoded)
                batch_objectives_tensor = torch.tensor(
                    batch_objectives,
                    dtype=torch.float64,
                    device=self.device
                )

                # Update training data all at once
                self.X_train = torch.vstack([self.X_train, batch_encoded_tensor])
                self.Y_train = torch.vstack([self.Y_train, batch_objectives_tensor])

                print(f"  Added {len(batch_encoded)} new points to training data")
                print(f"  Current training data size: {len(self.X_train)}")
                print(f"  Current Pareto front size: {len(self.best_solutions)}")

        # Return the best solutions found
        return [(sol, obj.tolist()) for sol, obj in self.best_solutions]

    # def _add_if_nondominated(self, 
    #                     solution: List[int],
    #                     objectives: List[float]) -> None:
    #     """Update Pareto front with proper dominance for problem type - FIXED for multi-objective"""
    #     objectives = np.array(objectives)
        
    #     print(f"  DEBUG: Adding solution with objectives: {objectives}")
        
    #     if len(self.best_solutions) == 0:
    #         self.best_solutions.append((solution, objectives))
    #         print("    First solution added to Pareto front")
    #         return
        
    #     dominated = False
    #     solutions_to_remove = []
        
    #     for i, (_, existing_obj) in enumerate(self.best_solutions):
    #         if self.is_knapsack:  # MAXIMIZATION
    #             # For maximization: A dominates B if A >= B in all and A > B in at least one
    #             if np.all(objectives >= existing_obj) and np.any(objectives > existing_obj):
    #                 solutions_to_remove.append(i)
    #                 print(f"    Solution {i} will be removed (dominated by new solution)")
    #             elif np.all(existing_obj >= objectives) and np.any(existing_obj > objectives):
    #                 dominated = True
    #                 print(f"    New solution is dominated by solution {i}, will not be added")
    #                 break
    #         else:  # MINIMIZATION (TSP)
    #             # For minimization: A dominates B if A <= B in all and A < B in at least one
    #             if np.all(objectives <= existing_obj) and np.any(objectives < existing_obj):
    #                 solutions_to_remove.append(i)
    #                 print(f"    Solution {i} will be removed (dominated by new solution)")
    #             elif np.all(existing_obj <= objectives) and np.any(existing_obj < objectives):
    #                 dominated = True
    #                 print(f"    New solution is dominated by solution {i}, will not be added")
    #                 break
        
    #     if not dominated:
    #         self.best_solutions = [
    #             sol for i, sol in enumerate(self.best_solutions)
    #             if i not in solutions_to_remove
    #         ]
    #         self.best_solutions.append((solution, objectives))
    #         print(f"    Solution added. Current Pareto front size: {len(self.best_solutions)}")
    #     else:
    #         print("    Solution not added (dominated)")

    def _add_if_nondominated(self,
                    solution: List[int],
                    objectives: List[float]) -> None:
        """Offer a solution to the archive ``self.best_solutions``.

        The archive stores ``(solution, objectives_array)`` pairs. Acceptance
        follows three rules, applied in order:

        1. While the archive holds fewer than 20 entries, a candidate that is
           at least 0.5 away (Euclidean, objective space) from every entry is
           appended without any dominance test.
        2. Otherwise the candidate is compared with each entry. Knapsack
           problems use plain dominance for maximisation; all other problems
           use minimisation with an epsilon of 0.1. Entries beaten by the
           candidate are removed, except that removals are cut back so the
           archive does not drop below 5 existing entries.
        3. A candidate beaten by some entry is discarded, unless the archive
           has fewer than 5 entries and the candidate is at least 2.0 away
           from all of them.

        Args:
            solution: the decision vector to archive.
            objectives: its objective values.
        """
        objectives = np.array(objectives)

        print(f"  DEBUG: Adding solution with objectives: {objectives}")

        def closer_than(radius):
            # True when some archived objective vector lies within `radius`
            return any(
                np.linalg.norm(objectives - archived) < radius
                for _, archived in self.best_solutions
            )

        # Rule 1: seed the archive with mutually distant points
        if len(self.best_solutions) < 20 and not closer_than(0.5):
            self.best_solutions.append((solution, objectives))
            print(f"    Added for initial diversity. Pareto front size: {len(self.best_solutions)}")
            return

        epsilon = 0.1  # slack used only by the minimisation comparison

        if self.is_knapsack:
            def beats(a, b):
                # Maximisation: a >= b everywhere and a > b somewhere
                return np.all(a >= b) and np.any(a > b)
        else:
            def beats(a, b):
                # Minimisation with epsilon slack
                return np.all(a <= b - epsilon) and np.any(a < b - epsilon)

        # Rule 2: pairwise comparison against the archive
        evicted = []
        for idx, (_, archived) in enumerate(self.best_solutions):
            if beats(objectives, archived):
                evicted.append(idx)
                print(f"    Solution {idx} will be removed (dominated by new solution)")
            elif beats(archived, objectives):
                print(f"    New solution is dominated by solution {idx}, will not be added")
                break
        else:
            # No archived entry beats the candidate
            front_size = len(self.best_solutions)
            if front_size - len(evicted) < 5:
                # Cut removals back so at least 5 existing entries survive
                evicted = evicted[:max(0, front_size - 5)]
                print("    Keeping minimum 5 solutions for diversity")

            self.best_solutions = [
                entry for idx, entry in enumerate(self.best_solutions)
                if idx not in evicted
            ]
            self.best_solutions.append((solution, objectives))
            print(f"    Solution added. Current Pareto front size: {len(self.best_solutions)}")
            return

        # Rule 3: the candidate was beaten by an archived entry
        if len(self.best_solutions) >= 5:
            print("    Solution not added (dominated)")
        elif not closer_than(2.0):
            self.best_solutions.append((solution, objectives))
            print(f"    Added for diversity despite dominance. Pareto front size: {len(self.best_solutions)}")


#------------------------------------------------------------------------------------------------------------------------
@dataclass
class BiKPMOBOConfig(MOBOConfig):
    """MOBOConfig variant carrying knapsack (BiKP) settings and tuned defaults.

    Adds four knapsack-only fields and overrides the defaults of several
    fields inherited from ``MOBOConfig``.
    """
    # --- Knapsack-only settings ---
    repair_infeasible: bool = True
    use_value_ratio_init: bool = True
    greedy_init_probability: float = 0.7
    capacity_margin: float = 0.95  # 0.95 = use 95% of capacity, for safety

    # --- Inherited fields with BiKP-specific defaults ---
    n_initial: int = 40
    n_iterations: int = 20
    q: int = 6  # Batch size; good for 2D problems
    pop_size: int = 80
    n_generations: int = 30
    crossover_prob: float = 0.8
    mutation_prob: float = 0.25  # Higher mutation for binary problems

    def to_dict(self):
        """Return a ``{field name: current value}`` dict for JSON serialization."""
        snapshot = {}
        for spec in fields(self):
            snapshot[spec.name] = getattr(self, spec.name)
        return snapshot


class BiKPMOBOqParEGO(MOBOqParEGO):
    """BiKP-optimized version of MOBOqParEGO with knapsack-specific improvements.

    On top of the parent optimizer this class provides a value/weight
    efficiency ranking, feasible random construction, capacity repair after
    crossover and mutation, maximization-style dominance, and min/max
    normalization of objectives before scalarization.

    NOTE(review): ``self.config``, ``self.n_objectives``, ``self.device``,
    ``self.solution_length``, ``self.is_tsp``, ``self.problem`` and
    ``self.best_solutions`` are read here but never assigned in this class;
    presumably the parent initializer provides them -- confirm.
    """

    def __init__(self, problem: MultiObjectiveKnapsack, config: Optional[Union[BiKPMOBOConfig, dict]] = None, **kwargs):
        """Set up the optimizer for a knapsack problem.

        Args:
            problem: Knapsack instance. Must expose ``n_items`` and
                ``capacity``; ``weights``, ``values`` and ``num_objectives``
                are read as well.
            config: A ``BiKPMOBOConfig``, a plain dict (converted through
                ``BiKPMOBOConfig.from_dict``), or ``None`` for the defaults.
            **kwargs: Forwarded unchanged to the parent initializer.

        Raises:
            ValueError: If ``problem`` lacks ``n_items`` or ``capacity``.
        """
        # Fail fast: reject non-knapsack problems before the parent does any setup.
        if not hasattr(problem, 'n_items') or not hasattr(problem, 'capacity'):
            raise ValueError("Problem must be a knapsack problem with n_items and capacity attributes")

        # Use BiKP-specific config if none provided
        if config is None:
            config = BiKPMOBOConfig()
        elif isinstance(config, dict):
            config = BiKPMOBOConfig.from_dict(config)

        super().__init__(problem, config, **kwargs)

        # BiKP-specific attributes
        self.n_items = problem.n_items
        self.capacity = problem.capacity
        self.item_weights = problem.weights
        self.item_values = [problem.values[i] for i in range(problem.num_objectives)]

        # Precompute value/weight ratios used by greedy construction, repair and mutation
        self._compute_efficiency_ratios()

        # Running per-objective bounds for normalization; filled lazily by
        # _update_objective_bounds on the first scalarization call.
        self.obj_min = None
        self.obj_max = None

        print(f"BiKP-MOBO initialized: {self.n_items} items, capacity {self.capacity}")
        print(f"Objectives: {self.n_objectives}, Repair enabled: {self.config.repair_infeasible}")

    def _compute_efficiency_ratios(self):
        """Build ``self.efficiency_ratios``: ``(ratio, item_index)`` sorted best-first.

        The ratio is the item's value averaged over all objectives divided by
        its weight; items whose weight is not positive get ``inf``.
        """
        ranked = []
        for i in range(self.n_items):
            weight = self.item_weights[i]
            if weight > 0:
                # Multi-objective efficiency: average value / weight
                avg_value = sum(self.item_values[j][i] for j in range(self.n_objectives)) / self.n_objectives
                ratio = avg_value / weight
            else:
                ratio = float('inf')  # Zero weight items are infinitely efficient
            ranked.append((ratio, i))

        # Sort by efficiency (descending); ties fall back to the item index.
        ranked.sort(reverse=True)
        self.efficiency_ratios = ranked

    # NOTE: this initializer does not call _greedy_knapsack_construction or
    # _random_knapsack_construction; both helpers are kept below for callers
    # that want the heuristic constructions.
    def _create_initial_solution(self) -> torch.Tensor:
        """Create one starting solution.

        Returns a random permutation when ``self.is_tsp`` is set; otherwise a
        0/1 vector built by visiting items in random order and taking each
        item that still fits with a fixed probability of 0.3, so the result
        never exceeds the problem capacity.
        """
        if self.is_tsp:
            # Random permutation - already simple
            return torch.randperm(self.solution_length)

        solution = torch.zeros(self.solution_length, dtype=torch.float64, device=self.device)

        # Fixed inclusion probability for all solutions
        selection_prob = 0.3

        # Attributes are read through self.problem rather than the cached
        # self.n_items / self.capacity copies -- presumably so this also works
        # if it is invoked before __init__ has finished; confirm with the parent.
        items = list(range(self.problem.n_items))
        random.shuffle(items)  # Random order to avoid bias

        remaining_capacity = self.problem.capacity
        for item_idx in items:
            if remaining_capacity >= self.problem.weights[item_idx]:
                if random.random() < selection_prob:
                    solution[item_idx] = 1
                    remaining_capacity -= self.problem.weights[item_idx]

        return solution

    def _greedy_knapsack_construction(self) -> torch.Tensor:
        """Construct a solution biased towards high value/weight items.

        Items are visited best-ratio-first within ``capacity * capacity_margin``.
        Each fitting item is taken with probability
        ``max(0.3, greedy_init_probability * ratio / best_ratio)``.
        """
        solution = torch.zeros(self.solution_length, dtype=torch.float64, device=self.device)
        remaining_capacity = self.capacity * self.config.capacity_margin

        if not self.efficiency_ratios:
            return solution
        best_ratio = self.efficiency_ratios[0][0]
        best_is_usable = best_ratio > 0 and best_ratio != float('inf')

        for ratio, item_idx in self.efficiency_ratios:
            if remaining_capacity >= self.item_weights[item_idx]:
                # Relative efficiency in [0, 1]. Guarded because the best ratio
                # can be 0 (division by zero) or inf (inf / inf is NaN).
                if ratio == best_ratio:
                    relative = 1.0
                elif best_is_usable:
                    relative = ratio / best_ratio
                else:
                    relative = 0.0
                # Higher probability for more efficient items
                inclusion_prob = max(0.3, self.config.greedy_init_probability * relative)
                if random.random() < inclusion_prob:
                    solution[item_idx] = 1
                    remaining_capacity -= self.item_weights[item_idx]

        return solution

    def _random_knapsack_construction(self) -> torch.Tensor:
        """Construct a random solution within ``capacity * capacity_margin``.

        Items are visited in shuffled order; each fitting item is taken with
        probability 0.5.
        """
        solution = torch.zeros(self.solution_length, dtype=torch.float64, device=self.device)
        remaining_capacity = self.capacity * self.config.capacity_margin

        items = list(range(self.n_items))
        random.shuffle(items)

        for item_idx in items:
            if remaining_capacity >= self.item_weights[item_idx]:
                if random.random() < 0.5:  # 50% chance to include
                    solution[item_idx] = 1
                    remaining_capacity -= self.item_weights[item_idx]

        return solution

    def _repair_solution(self, solution: torch.Tensor) -> torch.Tensor:
        """Drop the least efficient selected items until the weight fits.

        Returns the input tensor itself when it is already within capacity,
        otherwise a repaired clone (the input is not modified).
        """
        current_weight = self._get_solution_weight(solution)

        if current_weight <= self.capacity:
            return solution  # Already feasible

        solution = solution.clone()

        # Selected items ordered worst-ratio-first (ties broken by item index)
        selected_items = sorted(
            (ratio, idx) for ratio, idx in self.efficiency_ratios if solution[idx] > 0
        )

        # Remove items until feasible
        for _, item_idx in selected_items:
            solution[item_idx] = 0
            current_weight -= self.item_weights[item_idx]
            if current_weight <= self.capacity:
                break

        return solution

    def _get_solution_weight(self, solution: torch.Tensor) -> float:
        """Return the weighted sum ``sum(item_weights[i] * solution[i])``."""
        # One tolist() call instead of one .item() call per element.
        return sum(self.item_weights[i] * gene for i, gene in enumerate(solution.tolist()))

    def _uniform_crossover(self, parent1: torch.Tensor, parent2: torch.Tensor) -> torch.Tensor:
        """Uniform crossover: each gene comes from either parent with p=0.5.

        The child is repaired when ``config.repair_infeasible`` is set.
        """
        mask = torch.rand_like(parent1) < 0.5
        child = torch.where(mask, parent1, parent2)

        if self.config.repair_infeasible:
            child = self._repair_solution(child)

        return child

    def _bit_flip_mutation(self, solution: torch.Tensor) -> torch.Tensor:
        """Mutate a copy of ``solution`` and return it.

        With probability 0.7 a "smart" move is used: either try to add an
        efficient item that still fits (each candidate accepted with p=0.3,
        so the move may end up changing nothing), or remove one randomly
        chosen selected item. Otherwise a single random bit is flipped.
        The result is repaired when ``config.repair_infeasible`` is set.
        """
        mutated = solution.clone()

        if random.random() < 0.7:  # 70% smart mutation
            current_weight = self._get_solution_weight(mutated)
            remaining_capacity = self.capacity - current_weight

            if random.random() < 0.5 and remaining_capacity > 0:
                # Try to add an efficient item that fits
                for ratio, item_idx in self.efficiency_ratios:
                    if mutated[item_idx] == 0 and self.item_weights[item_idx] <= remaining_capacity:
                        if random.random() < 0.3:  # 30% chance to add
                            mutated[item_idx] = 1
                            break
            else:
                # Remove a random selected item
                selected_items = [i for i in range(len(mutated)) if mutated[i] > 0]
                if selected_items:
                    item_to_remove = random.choice(selected_items)
                    mutated[item_to_remove] = 0
        else:
            # Standard random bit flip
            idx = random.randrange(len(solution))
            mutated[idx] = 1 - mutated[idx]

        if self.config.repair_infeasible:
            mutated = self._repair_solution(mutated)

        return mutated

    def _add_if_nondominated(self, solution: List[int], objectives: List[float]) -> None:
        """Insert ``(solution, objectives)`` into ``self.best_solutions`` if not dominated.

        Dominance is for maximization: A dominates B when A >= B in every
        objective and A > B in at least one. Entries dominated by the new
        point are dropped. A point whose objectives equal an existing entry's
        is neither dominating nor dominated, so it is appended as well.
        """
        objectives = np.array(objectives)

        if len(self.best_solutions) == 0:
            self.best_solutions.append((solution, objectives))
            return

        dominated = False
        solutions_to_remove = set()

        for i, (_, existing_obj) in enumerate(self.best_solutions):
            if all(objectives >= existing_obj) and any(objectives > existing_obj):
                solutions_to_remove.add(i)
            elif all(existing_obj >= objectives) and any(existing_obj > objectives):
                dominated = True
                break

        if not dominated:
            self.best_solutions = [
                sol for i, sol in enumerate(self.best_solutions)
                if i not in solutions_to_remove
            ]
            self.best_solutions.append((solution, objectives))

    def _update_objective_bounds(self, objectives: torch.Tensor):
        """Widen the running per-objective min/max with a batch of objectives.

        The reduction runs over dim 0 -- assumes rows are points and columns
        are objectives; TODO confirm against callers.
        """
        batch_min = objectives.min(dim=0).values
        batch_max = objectives.max(dim=0).values
        if self.obj_min is None:
            self.obj_min = batch_min
            self.obj_max = batch_max
        else:
            self.obj_min = torch.min(self.obj_min, batch_min)
            self.obj_max = torch.max(self.obj_max, batch_max)

    def _normalize_objectives(self, objectives: torch.Tensor) -> torch.Tensor:
        """Map objectives to roughly [0, 1] using the tracked bounds.

        Returns the input unchanged while no bounds have been recorded. The
        observed maximum maps to ~1 and the observed minimum to 0; ``1e-8`` is
        added to the range to avoid division by zero.
        """
        if self.obj_min is None:
            return objectives  # No normalization before the first bounds update

        range_vals = self.obj_max - self.obj_min + 1e-8
        return (objectives - self.obj_min) / range_vals

    def _scalarize_objectives(self, Y, weights):
        """Scalarize ``Y`` after refreshing the bounds and normalizing.

        This REPLACES the parent's ``_scalarize_objectives`` for BiKP instances.
        Returns ``max_i(w_i * y_i) + rho * sum_i(w_i * y_i)`` per row, with
        dim 1 kept (``keepdim=True``).
        """
        self._update_objective_bounds(Y)
        Y_norm = self._normalize_objectives(Y)

        weighted = Y_norm * weights.view(1, -1)
        max_term = weighted.max(dim=1, keepdim=True).values
        sum_term = self.config.rho * weighted.sum(dim=1, keepdim=True)
        return max_term + sum_term

    def _create_scalarization(self, weights: torch.Tensor):
        """Return a ``GenericMCObjective`` applying the normalized scalarization.

        This REPLACES the parent's ``_create_scalarization`` for BiKP instances.
        The objective bounds are read when the objective is evaluated, not when
        it is created, and the last dimension of ``samples`` is reduced away.
        """
        def scalarize(samples, X=None):
            # Same normalization as _scalarize_objectives; the bounds broadcast
            # over the leading dimensions of ``samples``.
            samples_norm = self._normalize_objectives(samples)
            weighted = samples_norm * weights.view(1, 1, -1)
            max_term = weighted.max(dim=-1).values
            sum_term = self.config.rho * weighted.sum(dim=-1)
            return max_term + sum_term

        return GenericMCObjective(scalarize)
    
class BiKPMOBOWrapper:
    """Convenience runner that builds a config and executes BiKPMOBOqParEGO."""

    def __init__(self, problem: MultiObjectiveKnapsack, **kwargs):
        """Store the problem and resolve the configuration.

        If ``config`` is among ``kwargs`` it is used (dicts are converted via
        ``BiKPMOBOConfig.from_dict``) and every other keyword is ignored.
        Otherwise a default ``BiKPMOBOConfig`` is created and each keyword that
        names an existing config attribute overrides it.
        """
        self.problem = problem
        self.algorithm_name = "BiKP-MOBO-qParEGO"

        if 'config' not in kwargs:
            # No explicit config: start from defaults, apply matching overrides.
            self.config = BiKPMOBOConfig()
            for option, setting in kwargs.items():
                if hasattr(self.config, option):
                    setattr(self.config, option, setting)
            return

        supplied = kwargs['config']
        if isinstance(supplied, dict):
            supplied = BiKPMOBOConfig.from_dict(supplied)
        self.config = supplied

    def run(self) -> List[Tuple[List[int], List[float]]]:
        """Run the optimizer and return only solutions within capacity.

        The capacity filter is a safety net applied to whatever the optimizer
        returns.
        """
        optimizer = BiKPMOBOqParEGO(self.problem, self.config)
        candidates = optimizer.run()

        kept = []
        for sol, obj in candidates:
            load = sum(self.problem.weights[i] * sol[i] for i in range(len(sol)))
            if load > self.problem.capacity:
                continue
            kept.append((sol, obj))
        return kept



#------------------------------------------------------------------------------------------------------------------------
#------------------------------------------------------------------------------------------------------------------------

class CVRPWrapper:
    """
    Wraps a CVRP problem to look like a TSP problem for MOBOqParEGO.

    The wrapper:
    1. Exposes n_cities = n_customers (triggers is_tsp=True in MOBOqParEGO)
    2. Intercepts evaluate() to decode permutation → routes first
    3. Copies the attributes needed for decoding from the underlying CVRP

    Index convention: permutations handed to this class are 0-indexed
    (values ``0 .. n_customers-1``); routes contain 1-indexed customer IDs,
    with index 0 of ``distances`` used as the depot.

    Args:
        cvrp_problem: A BiObjectiveCVRP instance (or any CVRP with same interface)
    """

    def __init__(self, cvrp_problem):
        self.cvrp = cvrp_problem

        # KEY: This attribute triggers is_tsp=True in MOBOqParEGO
        self.n_cities = cvrp_problem.n_customers

        # Store CVRP-specific attributes for route decoding
        self.n_customers = cvrp_problem.n_customers
        self.vehicle_capacity = cvrp_problem.vehicle_capacity
        self.customers = cvrp_problem.customers

        # Pass through other attributes
        self.n_vehicles = cvrp_problem.n_vehicles
        self.num_objectives = cvrp_problem.num_objectives
        self.distances = cvrp_problem.distances

    def evaluate(self, permutation: List[int]) -> Tuple[float, float]:
        """
        Evaluate a permutation by decoding to routes first.

        Args:
            permutation: List of customer indices (0-indexed, like TSP city indices)
                        e.g., [0, 3, 1, 2, 4] for 5 customers

        Returns:
            Whatever the underlying ``cvrp.evaluate(routes)`` returns
            (annotated as a ``(total_distance, makespan)`` tuple).
        """
        # _split_into_routes performs the 0-indexed -> 1-indexed shift itself,
        # so the permutation must be passed through unshifted.
        routes = self._split_into_routes([int(c) for c in permutation])
        return self.cvrp.evaluate(routes)

    def _split_into_routes(self, perm: List[int]) -> List[List[int]]:
        """
        Decode a 0-indexed permutation into capacity-feasible routes.

        First pass: greedy split in permutation order, opening a new route
        whenever the next customer would exceed ``vehicle_capacity``.
        Second pass: rebalance the routes to reduce the makespan.

        Returns:
            List of routes; each route is a list of 1-indexed customer IDs.
        """
        routes = []
        current_route = []
        current_demand = 0

        for idx in perm:
            cid = int(idx) + 1  # 0-indexed position -> 1-indexed customer ID
            demand = self.customers[cid].demand

            if current_demand + demand <= self.vehicle_capacity:
                current_route.append(cid)
                current_demand += demand
            else:
                if current_route:
                    routes.append(current_route)
                current_route = [cid]
                current_demand = demand

        if current_route:
            routes.append(current_route)

        return self._balance_routes_for_makespan(routes)

    def _route_length(self, route: List[int]):
        """Return depot -> customers in order -> depot length (0 for an empty route)."""
        if not route:
            return 0
        length = self.distances[0, route[0]]  # depot to first
        for here, there in zip(route, route[1:]):
            length += self.distances[here, there]
        length += self.distances[route[-1], 0]  # last to depot
        return length

    def _route_demand(self, route: List[int]):
        """Return the total demand of the customers on ``route``."""
        return sum(self.customers[c].demand for c in route)

    def _balance_routes_for_makespan(self, routes: List[List[int]]) -> List[List[int]]:
        """
        Rebalance routes to reduce makespan (max route length).

        For up to 50 rounds, tries to move the last -- and, failing that, the
        first -- customer of the longest route onto the shortest route. A move
        is applied only if it respects capacity and lowers the overall
        makespan by more than 0.01. Stops at the first round without a move.

        Entries of ``routes`` may be replaced in place; the returned list
        omits empty routes.
        """
        if len(routes) <= 1:
            return routes

        max_iterations = 50
        for _ in range(max_iterations):
            lengths = [self._route_length(r) for r in routes]
            longest_idx = int(np.argmax(lengths))
            shortest_idx = int(np.argmin(lengths))

            if longest_idx == shortest_idx:
                break

            longest_route = routes[longest_idx]
            if len(longest_route) <= 1:
                break  # Never empty a route by moving its only customer

            shortest_route = routes[shortest_idx]
            shortest_demand = self._route_demand(shortest_route)
            old_makespan = lengths[longest_idx]

            # Longest of the untouched routes. With exactly two routes there
            # are none, hence the default (a bare max() would raise ValueError).
            other_max = max(
                (length for i, length in enumerate(lengths)
                 if i != longest_idx and i != shortest_idx),
                default=0.0,
            )

            # Candidate moves in priority order: (customer, new longest, new shortest)
            tail, head = longest_route[-1], longest_route[0]
            moves = (
                (tail, longest_route[:-1], shortest_route + [tail]),
                (head, longest_route[1:], [head] + shortest_route),
            )

            improved = False
            for customer, new_longest, new_shortest in moves:
                if shortest_demand + self.customers[customer].demand > self.vehicle_capacity:
                    continue
                new_makespan = max(
                    self._route_length(new_longest),
                    self._route_length(new_shortest),
                    other_max,
                )
                if new_makespan < old_makespan - 0.01:  # Small improvement threshold
                    routes[longest_idx] = new_longest
                    routes[shortest_idx] = new_shortest
                    improved = True
                    break

            if not improved:
                break

        # Remove empty routes
        return [r for r in routes if r]

    def random_solution(self) -> List[int]:
        """
        Generate a random 0-indexed permutation of the customers using
        ``np.random.shuffle`` (kept for compatibility with callers that
        expect this method).
        """
        perm = list(range(self.n_customers))
        np.random.shuffle(perm)
        return perm

    def decode_solution(self, permutation: List[int]) -> List[List[int]]:
        """
        Utility method to decode a permutation to routes.
        Useful for interpreting results after optimization.

        Args:
            permutation: 0-indexed permutation from optimizer

        Returns:
            List of routes (1-indexed customer IDs)
        """
        # Pass the permutation through unshifted; _split_into_routes adds the +1.
        return self._split_into_routes([int(c) for c in permutation])


# ============================================================================
# RESULT DECODER (for post-processing optimizer output)
# ============================================================================
class CVRPResultDecoder:
    """
    Turns optimizer output (permutations) back into CVRP routes by
    delegating to a ``CVRPWrapper``.

    Usage:
        results = optimizer.run()
        decoder = CVRPResultDecoder(wrapped_problem)

        for solution, objectives in results:
            routes = decoder.decode(solution)
            print(f"Routes: {routes}, Obj: {objectives}")
    """

    def __init__(self, wrapped_problem: CVRPWrapper):
        # Wrapper whose decode_solution() performs the actual decoding.
        self.wrapper = wrapped_problem

    def decode(self, solution: List[int]) -> List[List[int]]:
        """Return the routes for one permutation via ``wrapper.decode_solution``."""
        routes = self.wrapper.decode_solution(solution)
        return routes

    def decode_all(self, results: List[Tuple[List[int], List[float]]]) -> List[Tuple[List[List[int]], List[float]]]:
        """Decode every ``(solution, objectives)`` pair, keeping objectives as-is."""
        return [(self.decode(solution), objectives) for solution, objectives in results]


class CVRPMOBOqParEGO:
    """
    FIXED Algorithm wrapper for CVRP that works with benchmarking code.
    
    Changes from original:
    1. Uses ISOTROPIC kernel (not ARD) - critical for 50D
    2. Proper tensor shapes in acquisition
    3. Better error handling with full tracebacks
    4. Fallback to random when acquisition fails
    """
    
    def __init__(self, problem, config=None, **kwargs):
        self.original_problem = problem
        self.config = config or {}
        self.kwargs = kwargs
        self.algorithm_name = "CVRP-MOBO-qParEGO"
        
        # Problem attributes
        self.n_customers = problem.n_customers
        self.vehicle_capacity = problem.vehicle_capacity
        self.customers = problem.customers
        
        if isinstance(problem.distances, torch.Tensor):
            self.distances = problem.distances.cpu().numpy()
        else:
            self.distances = np.array(problem.distances)
        
        # Config with defaults
        self.n_initial = self.config.get('n_initial', 50)
        self.n_iterations = self.config.get('n_iterations', 50)
        self.q = self.config.get('q', 4)
        self.pop_size = self.config.get('pop_size', 150)
        self.n_generations = self.config.get('n_generations', 40)
        self.crossover_prob = self.config.get('crossover_prob', 0.85)
        self.mutation_prob = self.config.get('mutation_prob', 0.15)
        self.tournament_size = self.config.get('tournament_size', 3)
        self.max_train_size = self.config.get('max_train_size', 250)
        self.rho = self.config.get('rho', 0.05)
        
        # State
        self.device = torch.device("cpu")
        self.dtype = torch.float64
        self.n_objectives = 2
        
        self.X_train = None
        self.Y_train = None
        self.obj_min = None
        self.obj_max = None
        self.best_solutions = []
        
        # Track failures for fallback
        self.consecutive_failures = 0
        self.max_failures_before_fallback = 5
        
        print(f"CVRP-MOBO initialized: {self.n_customers} customers")
    
    # =========================================================================
    # ENCODING & DECODING
    # =========================================================================
    
    def encode_solution(self, perm: torch.Tensor) -> torch.Tensor:
        """Position encoding: for customer c at position p, encoded[c] = p/(n-1)"""
        n = len(perm)
        encoded = torch.zeros(n, dtype=self.dtype, device=self.device)
        for pos, cust in enumerate(perm.tolist()):
            encoded[int(cust)] = pos / max(n - 1, 1)
        return encoded
    
    def _split_into_routes(self, perm: List[int]) -> List[List[int]]:
        """
        Decode permutation to routes with MAKESPAN-AWARE splitting.
        Creates more routes with fewer customers to reduce makespan.
        """
        # Target makespan (should be below reference point)
        TARGET_MAKESPAN = 3.5  # Below reference of 4
        
        routes = []
        current_route = []
        current_demand = 0
        current_length = 0.0
        last_pos = 0  # depot
        
        for idx in perm:
            cid = idx + 1  # 0-indexed to 1-indexed
            demand = self.customers[cid].demand
            
            # Calculate distance to add this customer
            dist_to_customer = self.distances[last_pos, cid]
            dist_back_to_depot = self.distances[cid, 0]
            
            # Projected route length if we add this customer
            # (current length - return to depot + to customer + new return)
            if current_route:
                old_return = self.distances[last_pos, 0]
                new_length = current_length - old_return + dist_to_customer + dist_back_to_depot
            else:
                new_length = dist_to_customer + dist_back_to_depot
            
            # Check both capacity AND makespan constraints
            capacity_ok = (current_demand + demand <= self.vehicle_capacity)
            makespan_ok = (new_length <= TARGET_MAKESPAN)
            
            if capacity_ok and makespan_ok:
                # Add customer to current route
                current_route.append(cid)
                current_demand += demand
                current_length = new_length
                last_pos = cid
            else:
                # Start a new route
                if current_route:
                    routes.append(current_route)
                current_route = [cid]
                current_demand = demand
                current_length = self.distances[0, cid] + self.distances[cid, 0]
                last_pos = cid
        
        if current_route:
            routes.append(current_route)
        
        return routes

    def _evaluate_perm(self, perm: List[int]) -> Tuple[float, float]:
        """Evaluate a permutation"""
        routes = self._split_into_routes(perm)
        
        total_dist = 0
        route_lengths = []
        
        for route in routes:
            length = 0
            pos = 0
            for cid in route:
                length += self.distances[pos, cid]
                pos = cid
            length += self.distances[pos, 0]
            
            total_dist += length
            route_lengths.append(length)
        
        makespan = max(route_lengths) if route_lengths else 0
        return total_dist, makespan
    
    # =========================================================================
    # GENETIC OPERATORS
    # =========================================================================
    
    # def _create_solution(self) -> torch.Tensor:
    #     return torch.randperm(self.n_customers, device=self.device)
    
    def _create_solution(self) -> torch.Tensor:
        """Create solution with cluster-aware ordering to help makespan"""
        # Sort customers by angle from depot for better clustering
        depot_x, depot_y = 0.5, 0.5  # depot position
        
        angles = []
        for cid in range(1, self.n_customers + 1):
            cx = self.customers[cid].x
            cy = self.customers[cid].y
            angle = np.arctan2(cy - depot_y, cx - depot_x)
            angles.append((angle, cid - 1))  # store 0-indexed
        
        # Sort by angle (creates natural clusters)
        angles.sort()
        
        # Add some randomization
        perm = [cid for _, cid in angles]
        
        # Random swaps for diversity
        for _ in range(self.n_customers // 5):
            i, j = random.sample(range(self.n_customers), 2)
            perm[i], perm[j] = perm[j], perm[i]
        
        return torch.tensor(perm, device=self.device)


    def _tournament_select(self, pop, fitness):
        idx = random.sample(range(len(pop)), min(self.tournament_size, len(pop)))
        return pop[max(idx, key=lambda i: fitness[i])]
    
    def _order_crossover(self, p1, p2):
        n = len(p1)
        p1, p2 = p1.tolist(), p2.tolist()
        start, end = sorted(random.sample(range(n), 2))
        
        child = [-1] * n
        child[start:end] = p1[start:end]
        
        remaining = [c for c in p2 if c not in child[start:end]]
        idx = 0
        for i in range(n):
            if child[i] == -1:
                child[i] = remaining[idx]
                idx += 1
        
        return torch.tensor(child, device=self.device)
    
    def _swap_mutation(self, sol):
        mut = sol.clone()
        i, j = random.sample(range(len(sol)), 2)
        mut[i], mut[j] = mut[j].clone(), mut[i].clone()
        return mut
    
    def _inversion_mutation(self, sol):
        mut = sol.clone()
        i, j = sorted(random.sample(range(len(sol)), 2))
        mut[i:j+1] = torch.flip(mut[i:j+1], [0])
        return mut
    
    # =========================================================================
    # GP MODEL - FIXED with isotropic kernel
    # =========================================================================
    
    def build_model_list(self):
        """Build one GP surrogate per objective with an isotropic Matern kernel.

        Side effects: ``self.X_train`` / ``self.Y_train`` are truncated to the
        most recent ``self.max_train_size`` rows and de-duplicated via
        ``self._remove_duplicates``.

        For each objective column a ``SingleTaskGP`` is built, or a
        ``SingleTaskVariationalGP`` with 50 randomly chosen inducing points
        when more than 50 valid rows are available. The isotropic kernel is
        handed to the model constructor: assigning ``covar_module`` on a
        ``SingleTaskVariationalGP`` afterwards only sets an attribute on the
        wrapper, not on the inner ``model.model`` that is actually fitted.

        Returns:
            ModelListGP wrapping one model per objective.

        Raises:
            RuntimeError: if fewer than 3 finite training rows remain.
        """
        # Limit training data to the most recent rows
        if self.X_train.shape[0] > self.max_train_size:
            self.X_train = self.X_train[-self.max_train_size:]
            self.Y_train = self.Y_train[-self.max_train_size:]

        # Remove duplicates
        self.X_train, self.Y_train = self._remove_duplicates(self.X_train, self.Y_train)

        n_points = self.X_train.shape[0]
        print(f"  Building GP with {n_points} points")

        X = self.X_train.clone().to(self.device, self.dtype)
        Y = self.Y_train.clone().to(self.device, self.dtype)

        # Drop rows containing NaN/Inf in either inputs or outputs
        valid = ~(torch.isnan(X).any(1) | torch.isinf(X).any(1) |
                  torch.isnan(Y).any(1) | torch.isinf(Y).any(1))
        X, Y = X[valid], Y[valid]
        n_points = X.shape[0]

        if n_points < 3:
            raise RuntimeError(f"Too few training points: {n_points}")

        def make_isotropic_kernel():
            # No ard_num_dims => a single shared lengthscale. Cast to the
            # training data's device AND dtype so kernel parameters match X.
            return ScaleKernel(
                MaternKernel(nu=2.5),
                outputscale_prior=gpytorch.priors.GammaPrior(2.0, 0.15)
            ).to(device=X.device, dtype=X.dtype)

        models = []

        for obj_idx in range(self.n_objectives):
            Y_obj = Y[:, obj_idx:obj_idx+1]

            # Near-constant targets break standardisation; add tiny jitter
            if Y_obj.std() < 1e-6:
                Y_obj = Y_obj + torch.randn_like(Y_obj) * 1e-4

            try:
                # Use sparse (variational) GP for larger datasets
                if n_points > 50:
                    ind_idx = torch.randperm(n_points)[:50]

                    model = SingleTaskVariationalGP(
                        X, Y_obj,
                        inducing_points=X[ind_idx].clone(),
                        covar_module=make_isotropic_kernel(),
                        outcome_transform=Standardize(m=1)
                    )
                else:
                    model = SingleTaskGP(
                        X, Y_obj,
                        covar_module=make_isotropic_kernel(),
                        outcome_transform=Standardize(m=1)
                    )

                model = model.to(self.device)

                # Fit model
                self._fit_model(model, Y_obj, n_points)
                models.append(model)
                print(f"    Objective {obj_idx}: OK")

            except Exception as e:
                print(f"    Objective {obj_idx} failed: {e}")
                # Fallback: unfitted exact GP with the library's default kernel
                model = SingleTaskGP(X, Y_obj, outcome_transform=Standardize(m=1))
                model = model.to(self.device)
                models.append(model)

        return ModelListGP(*models)
    
    def _fit_model(self, model, Y_obj, n_points):
        """Fit GP hyperparameters, retrying up to three times.

        Uses a ``VariationalELBO`` for ``SingleTaskVariationalGP`` models and
        an ``ExactMarginalLogLikelihood`` otherwise. Fitting is best-effort:
        if every attempt raises, the failure is reported and the model is
        left as it is (callers still use it).

        Args:
            model: The BoTorch model to fit in place.
            Y_obj: Training targets; only ``Y_obj.shape[0]`` is read here.
            n_points: Unused; kept for interface compatibility.
        """
        # NOTE(review): confirm that the installed BoTorch forwards `options`
        # to the optimizer; recent releases document
        # optimizer_kwargs={"options": {...}} for fit_gpytorch_mll instead.
        last_error = None
        for opts in ({'maxiter': 100}, {'maxiter': 50}, {'maxiter': 20}):
            try:
                if isinstance(model, SingleTaskVariationalGP):
                    mll = VariationalELBO(model.likelihood, model.model, num_data=Y_obj.shape[0])
                else:
                    mll = ExactMarginalLogLikelihood(model.likelihood, model)
                fit_gpytorch_mll(mll, options=opts)
                return
            except Exception as exc:
                # Exception (not a bare except) so Ctrl-C / SystemExit still
                # interrupt a long optimisation run.
                last_error = exc

        print(f"    GP fit failed after all attempts: {type(last_error).__name__}: {last_error}")
    
    def _remove_duplicates(self, X, Y, tol=1e-4):
        if X.shape[0] == 0:
            return X, Y
        keep = [0]
        for i in range(1, X.shape[0]):
            if not any(torch.norm(X[i] - X[j]) < tol for j in keep):
                keep.append(i)
        return X[keep], Y[keep]
    
    # =========================================================================
    # SCALARIZATION - FIXED tensor shapes
    # =========================================================================
    
    def _update_bounds(self, Y):
        if self.obj_min is None:
            self.obj_min = Y.min(0).values.clone()
            self.obj_max = Y.max(0).values.clone()
        else:
            self.obj_min = torch.min(self.obj_min, Y.min(0).values)
            self.obj_max = torch.max(self.obj_max, Y.max(0).values)
    
    def _sample_weights(self, n):
        """Sample weight vectors with consistent tensor shapes"""
        weights = []
        
        def add(w):
            weights.append(torch.tensor(w, device=self.device, dtype=self.dtype))
        
        # Fixed weights for good coverage
        add([0.95, 0.05]); add([0.05, 0.95]); add([0.5, 0.5])
        add([0.8, 0.2]); add([0.2, 0.8]); add([0.7, 0.3]); add([0.3, 0.7])
        add([0.6, 0.4]); add([0.4, 0.6])
        
        # Fill with Dirichlet
        remaining = max(0, n - len(weights))
        if remaining > 0:
            samples = dirichlet.rvs(np.ones(2) * 0.5, size=remaining)
            for s in samples:
                add(s)
        
        return torch.stack(weights[:n])
    
    def _create_scalarization(self, weights):
        """Create an augmented-Chebyshev MC objective for minimised objectives.

        The current ``self.obj_min`` / ``self.obj_max`` bounds, ``self.rho``
        and ``weights`` are captured at creation time, so later bound updates
        do not affect an already-built objective.

        The scalarised cost is ``max_i(w_i * f_i) + rho * sum_i(w_i * f_i)``
        on range-normalised objectives. The archive treats objectives as
        minimised, while BoTorch MC acquisition functions maximise the
        objective value, so the NEGATED cost is returned.

        Args:
            weights: 1-D tensor with one weight per objective.

        Returns:
            GenericMCObjective wrapping the scalarisation.
        """
        # Capture current bounds
        obj_min = self.obj_min.clone() if self.obj_min is not None else None
        obj_max = self.obj_max.clone() if self.obj_max is not None else None
        rho = self.rho
        w = weights.clone()

        def scalarize(samples, X=None):
            """Map posterior samples ``[..., n_obj]`` to utilities ``[...]``.

            All operations broadcast over the trailing objective dimension,
            so any number of leading sample/batch/q dimensions is accepted.
            """
            try:
                # Normalise each objective to roughly [0, 1]
                if obj_min is not None and obj_max is not None:
                    rng = (obj_max - obj_min) + 1e-8
                    samples = (samples - obj_min) / rng

                # Augmented weighted Chebyshev cost (lower is better)
                weighted = samples * w
                cost = weighted.max(dim=-1).values + rho * weighted.sum(dim=-1)

                # Acquisition functions maximise, so flip the sign.
                utility = -cost

                # Non-finite entries get the worst finite utility so they can
                # never look attractive to the optimiser.
                bad = torch.isnan(utility) | torch.isinf(utility)
                if bad.any():
                    finite = utility[~bad]
                    fill = finite.min() if finite.numel() > 0 else utility.new_zeros(())
                    utility = torch.where(bad, fill, utility)

                return utility

            except Exception as e:
                # Local import: traceback is not among the imports visible at
                # the top of this file.
                import traceback
                print(f"    Scalarization error: {e}")
                traceback.print_exc()
                # Return zeros on failure
                return torch.zeros(samples.shape[:-1], device=samples.device, dtype=samples.dtype)

        return GenericMCObjective(scalarize)
    
    # =========================================================================
    # ACQUISITION OPTIMIZATION - FIXED
    # =========================================================================
    
    def optimize_acquisition(self, acq_func, weights):
        """Maximise ``acq_func`` over customer permutations with a GA.

        The population starts from ``self.pop_size`` fresh solutions, with up
        to five members of ``self.best_solutions`` injected into distinct
        slots at the tail. Each generation creates ``self.pop_size``
        offspring by tournament selection, order crossover and mutation
        (given a mutation: 30% balance, otherwise an even split between swap
        and inversion), then keeps the best ``self.pop_size`` of parents plus
        offspring.

        Args:
            acq_func: Acquisition function passed to ``self._batch_acq_eval``.
            weights: Unused; kept for interface compatibility.

        Returns:
            The permutation tensor with the highest acquisition value.
        """
        pop = [self._create_solution() for _ in range(self.pop_size)]

        # Seed with archived solutions. Each seed gets its OWN slot (counting
        # back from the end) so later seeds do not overwrite earlier ones.
        slot = len(pop) - 1
        for sol, _ in self.best_solutions[:5]:
            if slot < 0:
                break
            try:
                if isinstance(sol, list) and len(sol) > 0 and isinstance(sol[0], list):
                    # Archive routes hold 1-based ids; permutations are 0-based.
                    flat = [c - 1 for route in sol for c in route]
                    if len(flat) == self.n_customers:
                        pop[slot] = torch.tensor(flat, device=self.device)
                        slot -= 1
            except Exception as exc:
                # Seeding is best-effort: report and move on.
                print(f"    Seed skipped: {type(exc).__name__}: {exc}")

        encoded = [self.encode_solution(s) for s in pop]
        fitness = self._batch_acq_eval(acq_func, encoded)

        for gen in range(self.n_generations):
            offspring = []
            offspring_enc = []

            for _ in range(self.pop_size):
                p1 = self._tournament_select(pop, fitness)
                p2 = self._tournament_select(pop, fitness)

                child = self._order_crossover(p1, p2) if random.random() < self.crossover_prob else p1.clone()

                if random.random() < self.mutation_prob:
                    if random.random() < 0.3:  # 30% balance mutation
                        child = self._balance_mutation(child)
                    elif random.random() < 0.5:
                        child = self._swap_mutation(child)
                    else:
                        child = self._inversion_mutation(child)

                offspring.append(child)
                offspring_enc.append(self.encode_solution(child))

            offspring_fit = self._batch_acq_eval(acq_func, offspring_enc)

            # Elitist (mu + lambda) selection: best pop_size survive
            combined = pop + offspring
            combined_fit = fitness + offspring_fit
            combined_enc = encoded + offspring_enc

            sorted_idx = np.argsort(combined_fit)[::-1]

            pop = [combined[i] for i in sorted_idx[:self.pop_size]]
            fitness = [combined_fit[i] for i in sorted_idx[:self.pop_size]]
            encoded = [combined_enc[i] for i in sorted_idx[:self.pop_size]]

        return pop[np.argmax(fitness)]
    
    def _batch_acq_eval(self, acq_func, encoded):
        """
        Batch evaluate acquisition function.
        
        FIXED: qNEI expects shape [batch, q, d] where q is the number of points
        to evaluate jointly. For single-point evaluation, q=1.
        """
        try:
            # Stack: [batch, d] -> [batch, 1, d] (q=1)
            batch = torch.stack(encoded).unsqueeze(1)
            
            with torch.no_grad():
                vals = acq_func(batch)
            
            vals = torch.where(
                torch.isnan(vals) | torch.isinf(vals),
                torch.full_like(vals, -1e10),
                vals
            )
            
            return vals.squeeze().tolist() if vals.numel() > 1 else [float(vals)]
            
        except Exception as e:
            # Full error for debugging
            print(f"    Batch acq error: {type(e).__name__}: {e}")
            # Fallback to individual
            return [self._single_acq_eval(acq_func, e) for e in encoded]
    
    def _single_acq_eval(self, acq_func, enc):
        try:
            x = enc.unsqueeze(0).unsqueeze(0)  # [1, 1, d]
            with torch.no_grad():
                val = acq_func(x)
            return -1e10 if (torch.isnan(val) or torch.isinf(val)) else float(val)
        except:
            return -1e10
    
    # =========================================================================
    # PARETO FRONT
    # =========================================================================
    
    def _add_if_nondominated(self, sol, obj):
        obj = np.array(obj)
        
        if not self.best_solutions:
            self.best_solutions.append((sol, obj))
            return
        
        dominated = False
        to_remove = []
        
        for i, (_, existing) in enumerate(self.best_solutions):
            # Minimization: A dominates B if A <= B in all and A < B in at least one
            if all(obj <= existing) and any(obj < existing):
                to_remove.append(i)
            elif all(existing <= obj) and any(existing < obj):
                dominated = True
                break
        
        if not dominated:
            self.best_solutions = [s for i, s in enumerate(self.best_solutions) if i not in to_remove]
            self.best_solutions.append((sol, obj))
    
    def _balance_mutation(self, sol):
        """Mutation that specifically targets makespan reduction"""
        # Decode to routes
        routes = self._split_into_routes(sol.tolist())
        
        if len(routes) < 2:
            return self._swap_mutation(sol)
        
        # Find longest and shortest routes
        lengths = []
        for route in routes:
            length = self.distances[0, route[0]]
            for i in range(len(route) - 1):
                length += self.distances[route[i], route[i+1]]
            length += self.distances[route[-1], 0]
            lengths.append(length)
        
        longest_idx = np.argmax(lengths)
        
        # Swap a customer from longest route with random position
        longest_route = routes[longest_idx]
        if longest_route:
            # Find position of a customer from longest route in permutation
            cust_to_move = random.choice(longest_route) - 1  # Convert to 0-indexed
            
            mut = sol.clone()
            pos1 = (mut == cust_to_move).nonzero(as_tuple=True)[0]
            if len(pos1) > 0:
                pos1 = pos1[0].item()
                pos2 = random.randint(0, len(sol) - 1)
                mut[pos1], mut[pos2] = mut[pos2].clone(), mut[pos1].clone()
                return mut
        
        return self._swap_mutation(sol)


    # =========================================================================
    # MAIN LOOP
    # =========================================================================
    
    def run(self):
        """Run the optimisation loop and return the non-dominated archive.

        Phase 1 evaluates ``self.n_initial`` generated solutions to form the
        initial training set. Phase 2 repeats ``self.n_iterations`` times:
        rebuild the GP surrogates, sample ``self.q`` weight vectors, optimise
        a qNoisyExpectedImprovement acquisition per weight vector, evaluate
        the proposed solution and append it to the training data. After
        ``self.max_failures_before_fallback`` consecutive iterations without
        a single successful proposal, ``self.q`` freshly generated solutions
        are evaluated instead. Errors inside an iteration are printed and the
        loop continues.

        Returns:
            List of ``(routes, objectives)`` tuples from
            ``self.best_solutions``, with objectives converted to lists.

        Raises:
            ValueError: if fewer than 3 initial solutions have finite
                objectives.
        """
        # Local import: the error handlers below call traceback.print_exc(),
        # and traceback is not among the imports visible at the top of the file.
        import traceback

        def record(perm, xs, ys):
            """Evaluate ``perm``; if all objectives are finite, append its
            encoding/objectives to ``xs``/``ys`` and offer it to the archive.
            Returns True when the solution was recorded."""
            obj = self._evaluate_perm(perm.tolist())
            if any(np.isnan(o) or np.isinf(o) for o in obj):
                return False
            xs.append(self.encode_solution(perm))
            ys.append(torch.tensor(obj, device=self.device, dtype=self.dtype))
            routes = self._split_into_routes(perm.tolist())
            self._add_if_nondominated(routes, obj)
            return True

        print(f"\nRunning CVRP-MOBO: {self.n_customers} customers")

        # Initial sampling
        X_init, Y_init = [], []
        for _ in range(self.n_initial):
            record(self._create_solution(), X_init, Y_init)

        if len(X_init) < 3:
            raise ValueError("Could not generate enough initial solutions")

        self.X_train = torch.stack(X_init)
        self.Y_train = torch.stack(Y_init)
        self._update_bounds(self.Y_train)

        print(f"Initial: {len(X_init)} solutions, Pareto: {len(self.best_solutions)}")
        print(f"Obj ranges: dist=[{self.Y_train[:, 0].min():.1f}, {self.Y_train[:, 0].max():.1f}], "
              f"make=[{self.Y_train[:, 1].min():.2f}, {self.Y_train[:, 1].max():.2f}]")

        # Main loop
        for it in range(self.n_iterations):
            print(f"Iter {it+1}/{self.n_iterations}", end=" ")

            try:
                model = self.build_model_list()
                weights_batch = self._sample_weights(self.q)

                new_X, new_Y = [], []
                successes = 0

                for w_idx, weights in enumerate(weights_batch):
                    try:
                        objective = self._create_scalarization(weights)

                        acq = qNoisyExpectedImprovement(
                            model=model,
                            objective=objective,
                            X_baseline=self.X_train,
                            sampler=SobolQMCNormalSampler(sample_shape=torch.Size([64]))
                        )

                        sol = self.optimize_acquisition(acq, weights)
                        if record(sol, new_X, new_Y):
                            successes += 1

                    except Exception as e:
                        # One failing weight vector must not abort the batch.
                        print(f"\n    Weight {w_idx} error: {type(e).__name__}: {e}")
                        traceback.print_exc()

                # Fallback if all acquisitions fail
                if successes == 0:
                    self.consecutive_failures += 1
                    print(f"(0 successes, fail={self.consecutive_failures})", end=" ")

                    if self.consecutive_failures >= self.max_failures_before_fallback:
                        print("\n    Fallback to random...")
                        for _ in range(self.q):
                            record(self._create_solution(), new_X, new_Y)
                        self.consecutive_failures = 0
                else:
                    self.consecutive_failures = 0

                if new_X:
                    self.X_train = torch.vstack([self.X_train, torch.stack(new_X)])
                    self.Y_train = torch.vstack([self.Y_train, torch.stack(new_Y)])
                    self._update_bounds(torch.stack(new_Y))

                print(f"| +{len(new_X)} | Train: {self.X_train.shape[0]}, Pareto: {len(self.best_solutions)}")

            except Exception as e:
                # Keep going: a failed iteration only costs that iteration.
                print(f"| Error: {type(e).__name__}: {e}")
                traceback.print_exc()

        # Return decoded routes
        return [(sol, obj.tolist()) for sol, obj in self.best_solutions]
    
    @property
    def total_evaluations(self):
        """Number of rows currently in ``self.X_train`` (0 when it is unset)."""
        if self.X_train is None:
            return 0
        return self.X_train.shape[0]




def test_tritsp():
    """Run MOBO-qParEGO on a 20-city tri-objective TSP and report the results.

    Builds the configuration dictionary, evaluates the algorithm through
    ``MOCOEvaluator`` once per run, prints runtime / hypervolume / front size,
    then saves all results to ``mobo_tritsp_results.json`` and generates the
    evaluator's report.
    """
    import math

    n_runs = 1
    tsp_size = 20

    # Reference point has one coordinate per objective (three here).
    evaluator = MOCOEvaluator(reference_point=(30.0, 30.0, 30.0))

    # Number of weight vectors per iteration grows with log(problem size).
    weight_vectors = int(6 * (1 + math.log(tsp_size / 20)))
    config = dict(
        n_initial=30,
        n_iterations=15,
        q=weight_vectors,
        pop_size=80,
        n_generations=50,
        crossover_prob=0.9,
        mutation_prob=0.2,
        tournament_size=5,
        matern_nu=2.5,
        use_sparse_gp=True,
        model_rebuild_interval=3,
    )

    print("\nTesting MOBO-qParEGO on Tri-Objective TSP:")
    print("=========================================")

    tsp_params = dict(n_cities=tsp_size)

    for run_number in range(1, n_runs + 1):
        print(f"\nRun {run_number}/{n_runs}")

        result = evaluator.evaluate_algorithm(
            algorithm_class=MOBOqParEGO,
            problem_class=TriObjectiveTSP,
            algorithm_name="MOBO-qParEGO-TriTSP",
            parameters={'config': config},
            problem_params=tsp_params,
            num_runs=1
        )

        print(f"Run {run_number} Results:")
        print(f"Runtime: {result.runtime:.2f} seconds")
        print(f"Hypervolume: {result.hypervolume:.4f}")
        print(f"Non-dominated solutions: {result.num_nondominated}")

    # Persist raw results and produce the summary report
    evaluator.save_all_results("mobo_tritsp_results.json")
    evaluator.generate_report()

if __name__ == "__main__":
    # Script entry point: run the tri-objective TSP test driver.
    test_tritsp()