"""Bayesian Optimization with Probabilistic Reparameterization (BOPR).

Implementation following Daulton et al. (NeurIPS 2022), "Bayesian Optimization
over Discrete and Mixed Spaces via Probabilistic Reparameterization".

Key features:
1. Probabilistic reparameterization optimized with REINFORCE (score-function) gradients
2. Plackett-Luce distribution over permutations (TSP, CVRP)
3. Independent Bernoulli distribution over binary vectors (Knapsack)
4. Multiple Monte Carlo samples per optimization step
5. Multiple random restarts
6. Variance reduction with a moving-average baseline
"""

import torch
import gpytorch
from botorch.models import SingleTaskGP, ModelListGP
from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from botorch.acquisition.multi_objective.logei import qLogNoisyExpectedHypervolumeImprovement
from botorch.utils.multi_objective.box_decompositions.dominated import DominatedPartitioning
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.transforms import normalize
from botorch.sampling.normal import SobolQMCNormalSampler
from typing import Tuple, List, Optional, Union
import numpy as np
import warnings

# Silence UserWarnings emitted from the linear_operator package (used by GPyTorch/BoTorch).
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="linear_operator",
)


###############################################################################
# Distributions for Probabilistic Reparameterization
###############################################################################

class PlackettLuceDistribution:
    """
    Plackett-Luce distribution over permutations.

    A permutation is built by repeatedly choosing the next item among the
    remaining ones with probability proportional to
    ``exp(logit / temperature)``. Used for permutation problems (TSP, CVRP).
    """

    def __init__(self, logits: torch.Tensor, temperature: float = 1.0):
        """
        Args:
            logits: Preference scores ``[..., n]``; a higher score makes an
                item more likely to be placed earlier.
            temperature: Positive temperature dividing the logits
                (higher = more random).

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.logits = logits
        self.temperature = temperature

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Sample permutations with the Gumbel-max trick.

        Sorting ``logits / temperature + Gumbel noise`` in descending order
        gives an exact Plackett-Luce sample.

        Args:
            sample_shape: Leading shape of independent draws.

        Returns:
            Long tensor of permutations, shape ``sample_shape + logits.shape``.
        """
        shape = torch.Size(sample_shape) + self.logits.shape
        # Standard Gumbel noise -log(-log(U)); the epsilons keep both logs finite.
        gumbel = -torch.log(-torch.log(
            torch.rand(shape, device=self.logits.device) + 1e-10
        ) + 1e-10)

        perturbed = self.logits / self.temperature + gumbel
        return torch.argsort(perturbed, dim=-1, descending=True)

    def log_prob(self, permutation: torch.Tensor) -> torch.Tensor:
        """
        Log probability of permutations under Plackett-Luce.

        With ``s = logits / temperature``:
        ``log p(pi) = sum_i [ s[pi_i] - logsumexp_{k >= i} s[pi_k] ]``.

        Args:
            permutation: Integer permutations ``[..., n]``. The batch shape
                is broadcast against the logits' batch shape, so a single
                permutation ``[n]`` is also accepted.

        Returns:
            Log probabilities with the broadcast batch shape. The result is
            differentiable with respect to ``logits``.
        """
        perm = permutation.long()
        n = perm.shape[-1]
        batch_shape = torch.broadcast_shapes(perm.shape[:-1], self.logits.shape[:-1])

        scores = (self.logits / self.temperature).expand(*batch_shape, n)
        perm = perm.expand(*batch_shape, n)

        # Scores reordered by the permutation: ordered[..., i] = s[pi_i].
        ordered = torch.gather(scores, -1, perm)
        # tail_lse[..., i] = logsumexp(ordered[..., i:]), the normalizer over
        # the items still available at step i.
        tail_lse = torch.logcumsumexp(ordered.flip(-1), dim=-1).flip(-1)
        return (ordered - tail_lse).sum(dim=-1)


# class BernoulliDistribution:
#     """
#     Independent Bernoulli distribution for binary vectors (Knapsack).
#     """
    
#     def __init__(self, logits: torch.Tensor, temperature: float = 1.0):
#         """
#         Args:
#             logits: Logits for each binary decision [..., n]
#             temperature: Temperature for sampling
#         """
#         self.logits = logits
#         self.temperature = temperature
        
#     def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
#         """
#         Sample binary vectors using Gumbel-Softmax trick.

#         Returns:
#             Binary vectors [..., n]
#         """
#         # Use Gumbel-Sigmoid for binary sampling
#         gumbel = -torch.log(-torch.log(
#             torch.rand(sample_shape + self.logits.shape, device=self.logits.device) + 1e-10
#         ) + 1e-10)

#         # Perturb and threshold (correct formula)
#         perturbed = self.logits / self.temperature + gumbel
#         return (perturbed > 0).float()
    
#     def log_prob(self, binary: torch.Tensor) -> torch.Tensor:
#         """
#         Compute log probability of binary vector.
        
#         Args:
#             binary: Binary tensor [..., n]
            
#         Returns:
#             Log probability [...]
#         """
#         # Bernoulli log probability
#         probs = torch.sigmoid(self.logits / self.temperature)
#         log_prob = binary * torch.log(probs + 1e-10) + (1 - binary) * torch.log(1 - probs + 1e-10)
#         return log_prob.sum(dim=-1)

class BernoulliDistribution:
    """
    Independent Bernoulli distribution over binary vectors (Knapsack).

    Entry i equals 1 with probability ``sigmoid(logits[i] / temperature)``.
    """

    def __init__(self, logits: torch.Tensor, temperature: float = 1.0):
        """
        Args:
            logits: Per-entry logits ``[..., n]``.
            temperature: Positive temperature dividing the logits.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.logits = logits
        self.temperature = temperature

    def sample(self, sample_shape=torch.Size()):
        """Sample 0/1 float tensors of shape ``sample_shape + logits.shape``."""
        probs = torch.sigmoid(self.logits / self.temperature)
        probs = probs.expand(torch.Size(sample_shape) + self.logits.shape)
        return torch.bernoulli(probs)

    def log_prob(self, binary):
        """
        Exact Bernoulli log-probability, summed over the last dimension.

        Uses ``log sigmoid(x)`` and ``log(1 - sigmoid(x)) = log sigmoid(-x)``.
        Both stay accurate, with non-vanishing gradients, for large ``|x|``,
        where ``log(p + eps)`` would saturate.
        """
        scaled = self.logits / self.temperature
        binary = binary.to(scaled.dtype)
        log_p_one = torch.nn.functional.logsigmoid(scaled)
        log_p_zero = torch.nn.functional.logsigmoid(-scaled)
        return (binary * log_p_one + (1 - binary) * log_p_zero).sum(-1)



###############################################################################
# Main BOPR Implementation
###############################################################################

class ProperBOPR:
    """
    Bayesian Optimization with Probabilistic Reparameterization (Daulton et al. 2022).

    Supports TSP and CVRP (permutations, via a Plackett-Luce distribution) and
    Knapsack (binary vectors, via independent Bernoulli distributions).
    """

    def __init__(
        self,
        problem,
        n_initial: int = 50,
        n_iterations: int = 30,
        mc_samples: int = 64,
        lr: float = 0.05,
        temperature: float = 1.0,
        n_restarts: int = 5,
        reference_point: Optional[Tuple[float, ...]] = None,
        device: Optional[str] = None,
        sparse_gp: bool = False,
        inducing_points: int = 100
    ):
        """
        Initialize PROPER BOPR.

        Args:
            problem: Problem instance (BiObjectiveTSP, TriObjectiveTSP, or MultiObjectiveKnapsack)
            n_initial: Number of initial random samples
            n_iterations: Number of BO iterations
            mc_samples: MC samples for REINFORCE gradient estimation (ACTUALLY USED NOW!)
            lr: Learning rate for acquisition optimization
            temperature: Temperature for probabilistic reparameterization
            n_restarts: Number of random restarts for acquisition optimization
            reference_point: Reference point for hypervolume
            device: Device to use (cuda or cpu)
            sparse_gp: Whether to use sparse variational GPs
            inducing_points: Number of inducing points for sparse GPs
        """
        self.problem = problem
        self.n_initial = n_initial
        self.n_iterations = n_iterations
        self.mc_samples = mc_samples
        self.lr = lr
        self.temperature = temperature
        self.n_restarts = n_restarts
        self.sparse_gp = sparse_gp
        self.inducing_points = inducing_points

        # Auto-detect problem type
        problem_class_name = problem.__class__.__name__
        # if 'TSP' in problem_class_name:
        #     self.problem_type = 'tsp'
        #     self.solution_dim = problem.n_cities
        #     self.problem_name = f"{problem.num_objectives}-obj TSP"
        if 'TSP' in problem_class_name:
            self.problem_type = 'tsp'
            self.n_cities = problem.n_cities
            self.solution_dim = self.n_cities * self.n_cities
            self.problem_name = f"{problem.num_objectives}-obj TSP"
            # In the TSP block, add:
            self.perm_dim = self.n_cities
        elif 'Knapsack' in problem_class_name:
            self.problem_type = 'knapsack'
            self.solution_dim = problem.n_items
            self.capacity = problem.capacity
            self.weights = torch.tensor(problem.weights, dtype=torch.float64)
            self.problem_name = f"{problem.num_objectives}-obj Knapsack"
        elif 'CVRP' in problem_class_name:
            self.problem_type = 'cvrp'
            self.n_customers = problem.n_customers
            self.solution_dim = self.n_customers
            self.perm_dim = self.n_customers
            self.vehicle_capacity = problem.vehicle_capacity
            self.customers = problem.customers
            self.problem_name = f"{problem.num_objectives}-obj CVRP"
            
        else:
            raise ValueError(f"Unknown problem class: {problem_class_name}")

        # Orientation: treat TSP as minimization, Knapsack as maximization
        self.minimization_problem = (self.problem_type in ['tsp', 'cvrp']) # (self.problem_type == 'tsp')


        # Get number of objectives
        self.n_objectives = problem.num_objectives

        # Set device
        if device is None:
            self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Data storage
        self.X_observed = []  # Observed solutions (as encodings)
        self.Y_observed = []  # Observed objectives

        # Reference point
        # Reference point: external (MOCO) is in raw space (minimization for TSP)
        self.external_reference_point = reference_point
        self.fixed_reference_point = reference_point is not None

        # Internal reference point used by BoTorch (always maximization)
        if reference_point is not None and self.minimization_problem:
            # Flip sign for internal maximization orientation
            ref_tensor = torch.tensor(reference_point, dtype=torch.float64)
            self.reference_point = tuple((-ref_tensor).tolist())
        else:
            # Either maximization problem already or no ref yet
            self.reference_point = reference_point


        # GP models
        self.gp_models = []

        # REINFORCE baseline for variance reduction
        self.baseline = 0.0
        self.baseline_decay = 0.9

        print(f"\n{'='*70}")
        print("PROPER BOPR IMPLEMENTATION (Daulton et al. 2022)")
        print(f"{'='*70}")
        print(f"Problem: {self.problem_name}")
        print(f"Solution dimension: {self.solution_dim}")
        print(f"Number of objectives: {self.n_objectives}")
        print(f"Device: {self.device}")
        print(f"n_initial: {n_initial}, n_iterations: {n_iterations}")
        print(f"MC samples: {mc_samples}, n_restarts: {n_restarts}")
        print(f"Temperature: {temperature}")
        print(f"{'='*70}\n")

    # def _encode_solution(self, solution: Union[torch.Tensor, List]) -> torch.Tensor:
    #     """Encode solution as continuous vector for GP."""
    #     if isinstance(solution, list):
    #         solution = torch.tensor(solution, dtype=torch.float64)

    #     if self.problem_type == 'tsp':
    #         # Normalize permutation values to [0, 1]
    #         return solution.double() / (self.solution_dim - 1)
    #     elif self.problem_type == 'knapsack':
    #         # Binary vector is already in {0, 1}
    #         return solution.double()

    ### latest
    # def _encode_solution(self, solution):
    #     """
    #     Encode a TSP tour into an edge-incidence vector of shape (n*n,).
    #     Supports:
    #         - hard permutations (int list or tensor)
    #         - soft continuous permutations from PR sampling (float tensor)
    #     """

    #     if self.problem_type != 'tsp':
    #         # Knapsack: identity encoding
    #         if isinstance(solution, list):
    #             return torch.tensor(solution, dtype=torch.float64)
    #         return solution.double()

    #     # ---------------------------
    #     # Convert to clean permutation
    #     # ---------------------------
    #     if isinstance(solution, list):
    #         perm = torch.tensor(solution, dtype=torch.long)
        
    #     elif isinstance(solution, torch.Tensor):
    #         if solution.dtype in (torch.float32, torch.float64):
    #             # soft continuous vector → convert to permutation
    #             perm = torch.argsort(solution)  # stable & deterministic
    #         else:
    #             # already long/int tensor
    #             perm = solution.long()
        
    #     else:
    #         raise ValueError("Unknown solution type for encoding")

    #     n = self.problem.n_cities
    #     assert perm.numel() == n, f"Permutation length {perm.numel()} != n({n})"

    #     # ---------------------------
    #     # Build edge incidence matrix
    #     # ---------------------------
    #     E = torch.zeros((n, n), dtype=torch.float64)

    #     # For each i → successor j
    #     for idx in range(n):
    #         i = perm[idx].item()
    #         j = perm[(idx + 1) % n].item()   # SAFE wrap-around
    #         E[i, j] = 1.0

    #     # Flatten to vector (n*n,)
    #     return E.reshape(-1)

    def _encode_solution(self, solution):
        """
        Encode a TSP tour into an edge-incidence vector of shape (n*n,).
        Supports:
            - hard permutations (int list or tensor)
            - soft continuous permutations from PR sampling (float tensor)
        """

        # CVRP: flatten routes to permutation, encode as positions
        if self.problem_type == 'cvrp':
            if isinstance(solution, list):
                if len(solution) > 0 and isinstance(solution[0], list):
                    perm = []
                    for route in solution:
                        perm.extend(route)
                else:
                    perm = solution
            else:
                perm = solution.tolist()
            
            encoding = torch.zeros(self.n_customers, dtype=torch.float64)
            for pos, cid in enumerate(perm):
                encoding[cid - 1] = pos / max(1, self.n_customers - 1)
            return encoding

        # Knapsack: identity encoding
        if self.problem_type == 'knapsack':
            if isinstance(solution, list):
                return torch.tensor(solution, dtype=torch.float64)
            return solution.double()

        # ---------------------------
        # TSP: Convert to clean permutation
        # ---------------------------
        if isinstance(solution, list):
            perm = torch.tensor(solution, dtype=torch.long)
        elif isinstance(solution, torch.Tensor):
            if solution.dtype in (torch.float32, torch.float64):
                perm = torch.argsort(solution)
            else:
                perm = solution.long()
        else:
            raise ValueError("Unknown solution type for encoding")

        n = self.problem.n_cities
        assert perm.numel() == n, f"Permutation length {perm.numel()} != n({n})"

        # ---------------------------
        # Build edge incidence matrix
        # ---------------------------
        E = torch.zeros((n, n), dtype=torch.float64)

        for idx in range(n):
            i = perm[idx].item()
            j = perm[(idx + 1) % n].item()
            E[i, j] = 1.0

        return E.reshape(-1)

    # def _decode_solution(self, encoded: torch.Tensor) -> Union[torch.Tensor, List[int]]:
    #     """Decode continuous vector back to valid solution."""
    #     if self.problem_type == 'tsp':
    #         # Argsort to convert continuous to permutation
    #         return torch.argsort(encoded)
    #     elif self.problem_type == 'knapsack':
    #         # Threshold to binary
    #         binary = (encoded > 0.5).long()
    #         # Enforce capacity constraint
    #         return self._enforce_capacity_constraint(binary)

    def _decode_solution(self, encoded):
        """
        Decode an edge-incidence vector back into a valid TSP tour.
        Ensures:
            - one outgoing edge
            - one incoming edge
            - single Hamiltonian cycle
        """
        
        # if self.problem_type != 'tsp':
        #     # Knapsack decode
        #     binary = (encoded > 0.5).long()
        #     return self._enforce_capacity_constraint(binary)
        if self.problem_type == 'cvrp':
            perm_0indexed = torch.argsort(encoded)
            customer_order = (perm_0indexed + 1).tolist()
            return self._split_into_routes(customer_order)

        if self.problem_type == 'knapsack':
            binary = (encoded > 0.5).long()
            return self._enforce_capacity_constraint(binary)

        # ----------------------------
        # 1. Reshape edge vector
        # ----------------------------
        n = self.problem.n_cities
        E = encoded.reshape(n, n)

        # Make edges non-negative
        E = torch.relu(E)

        # Remove self-loops
        E = E.clone()
        E[range(n), range(n)] = -1e9

        # ----------------------------
        # 2. Solve assignment problem: 
        #    choose permutation π minimizing -E[i, π(i)]
        # ----------------------------
        # This ensures bijection (1 outgoing, 1 incoming)
        cost_matrix = (-E).detach().cpu().numpy()

        from scipy.optimize import linear_sum_assignment
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        # successors[i] = j meaning i -> j
        successors = {int(i): int(j) for i, j in zip(row_ind, col_ind)}

        # ----------------------------
        # 3. Extract cycles
        # ----------------------------
        cycles = []
        visited = set()

        for start in range(n):
            if start in visited:
                continue
            cycle = []
            current = start
            while current not in visited:
                visited.add(current)
                cycle.append(current)
                current = successors[current]
            cycles.append(cycle)

        # ----------------------------
        # 4. If only one cycle → valid tour!
        # ----------------------------
        if len(cycles) == 1:
            return torch.tensor(cycles[0], dtype=torch.long)

        # ----------------------------
        # 5. Merge subtours into a single cycle
        # ----------------------------
        # We use a deterministic rule:
        # Connect end of cycle k to start of cycle k+1
        merged_tour = []
        for cycle in cycles:
            merged_tour.extend(cycle)

        # This produces a Hamiltonian path; close it into cycle
        # by shifting nodes (simple rotational fix)
        return torch.tensor(merged_tour, dtype=torch.long)

    def _enforce_capacity_constraint(self, binary_solution: torch.Tensor) -> torch.Tensor:
        """Enforce knapsack capacity constraint."""
        if self.problem_type != 'knapsack':
            return binary_solution

        binary = binary_solution.clone()
        total_weight = (binary * self.weights.to(binary.device)).sum()

        if total_weight <= self.capacity:
            return binary

        # Remove items until capacity is satisfied
        selected_indices = torch.where(binary == 1)[0]
        while total_weight > self.capacity and len(selected_indices) > 0:
            idx_to_remove = selected_indices[torch.randint(len(selected_indices), (1,))]
            binary[idx_to_remove] = 0
            total_weight = (binary * self.weights.to(binary.device)).sum()
            selected_indices = torch.where(binary == 1)[0]

        return binary

    def _initialize_random_samples(self):
        """
        Evaluate ``n_initial`` random solutions and store them as training data.

        Objectives are stored in maximization orientation (negated for
        TSP/CVRP). Fills ``X_observed``/``Y_observed`` and the stacked
        ``X_tensor``/``Y_tensor``. If no reference point was given, it derives
        one that lies strictly below the worst observed value of every
        objective.

        Raises:
            ValueError: If ``n_initial`` < 1 (there would be nothing to stack).
        """
        if self.n_initial < 1:
            raise ValueError(f"n_initial must be at least 1, got {self.n_initial}")

        print(f"Generating {self.n_initial} initial random samples...")

        for i in range(self.n_initial):
            solution = self.problem.random_solution()
            raw_obj = torch.tensor(self.problem.evaluate(solution), dtype=torch.float64)

            # Internal orientation: maximization
            objectives = -raw_obj if self.minimization_problem else raw_obj

            encoded = self._encode_solution(solution)
            self.X_observed.append(encoded.cpu())
            self.Y_observed.append(objectives.cpu())

            if (i + 1) % 10 == 0:
                print(f"  Sampled {i + 1}/{self.n_initial}")

        self.X_tensor = torch.stack(self.X_observed)
        self.Y_tensor = torch.stack(self.Y_observed)

        if self.reference_point is None:
            # In maximization space the reference point must be worse than
            # every observation.
            y_min = self.Y_tensor.min(dim=0).values
            y_max = self.Y_tensor.max(dim=0).values
            margin = 0.1 * y_min.abs()
            # The relative margin vanishes when an objective's minimum is
            # exactly 0. Fall back to 10% of the observed range, or to 1.0
            # for a constant objective.
            spread = 0.1 * (y_max - y_min)
            fallback = torch.where(spread > 0, spread, torch.ones_like(spread))
            margin = torch.where(margin > 0, margin, fallback)
            self.reference_point = tuple((y_min - margin).tolist())
            print(f"Computed internal reference point: {self.reference_point}")
        else:
            print(f"Using fixed internal reference point: {self.reference_point}")

        # Print initial stats
        pareto_mask = is_non_dominated(self.Y_tensor)
        n_pareto = pareto_mask.sum().item()
        hv = self._compute_hypervolume(self.Y_tensor[pareto_mask])
        print(f"Initial Pareto front: {n_pareto} solutions, HV: {hv:.4f}")

    def _fit_gp_models(self):
        """
        Fit one GP per objective on the observed data.

        Inputs are normalized with unit-box bounds. Since encodings already
        lie in [0, 1], this is an identity map, and the same bounds are reused
        for candidates. Outputs are standardized per objective. Sets
        ``gp_models``, ``X_bounds``, ``Y_mean`` and ``Y_std`` (the last three
        on ``self.device``). ``Y_mean``/``Y_std`` are needed to map raw-scale
        quantities into the models' standardized output space.
        """
        unit_bounds = torch.stack([
            torch.zeros(self.solution_dim),
            torch.ones(self.solution_dim),
        ])
        X_norm = normalize(self.X_tensor, bounds=unit_bounds)

        # Standardize outputs
        Y_mean = self.Y_tensor.mean(dim=0)
        Y_std = self.Y_tensor.std(dim=0)
        # std is NaN with a single observation and ~0 for a constant
        # objective; use unit scale in both cases so Y_norm stays finite.
        degenerate = torch.isnan(Y_std) | (Y_std < 1e-6)
        Y_std = torch.where(degenerate, torch.ones_like(Y_std), Y_std)
        Y_norm = (self.Y_tensor - Y_mean) / Y_std

        X_norm = X_norm.to(self.device)
        Y_norm = Y_norm.to(self.device)

        self.gp_models = []

        if self.sparse_gp:
            n_data = X_norm.shape[0]
            n_inducing = min(self.inducing_points, n_data)
            inducing_indices = torch.randperm(n_data)[:n_inducing]
            inducing_points = X_norm[inducing_indices]

            for obj_idx in range(self.n_objectives):
                model = SingleTaskVariationalGP(
                    X_norm, Y_norm[:, obj_idx:obj_idx+1],
                    inducing_points=inducing_points
                )
                mll = VariationalELBO(model.likelihood, model.model, num_data=n_data)
                fit_gpytorch_mll(mll)
                self.gp_models.append(model)
            print(f"  Sparse GP: {n_inducing}/{n_data} inducing points")
        else:
            for obj_idx in range(self.n_objectives):
                model = SingleTaskGP(X_norm, Y_norm[:, obj_idx:obj_idx+1])
                mll = ExactMarginalLogLikelihood(model.likelihood, model)
                fit_gpytorch_mll(mll)
                self.gp_models.append(model)
            print(f"  Exact GP: {X_norm.shape[0]} data points")

        self.X_bounds = unit_bounds.to(self.device)
        self.Y_mean = Y_mean.to(self.device)
        self.Y_std = Y_std.to(self.device)

    def _compute_hypervolume(self, pareto_front: torch.Tensor) -> float:
        """Compute hypervolume."""
        if len(pareto_front) == 0:
            return 0.0
        ref_point = torch.tensor(self.reference_point, dtype=pareto_front.dtype)
        bd = DominatedPartitioning(ref_point=ref_point, Y=pareto_front)
        return bd.compute_hypervolume().item()

    def _create_acquisition_function(self):
        """
        Build a qLogNEHVI acquisition function over the per-objective GPs.

        ``_fit_gp_models`` trains the GPs on standardized outcomes
        ``(Y - Y_mean) / Y_std``. The reference point is stored in raw,
        maximization-oriented units, so it is mapped through the same
        transform before being passed to BoTorch. ``_fit_gp_models`` must
        have run first; it sets ``gp_models``, ``X_bounds``, ``Y_mean`` and
        ``Y_std``.
        """
        raw_ref = torch.tensor(
            self.reference_point,
            dtype=torch.float64,
            device=self.device
        )
        # A per-objective affine map with positive scale preserves Pareto
        # dominance; it only moves the reference point into model space.
        ref_point = (raw_ref - self.Y_mean) / self.Y_std

        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([self.mc_samples]))
        model_list = ModelListGP(*self.gp_models)

        # Use qLogNEHVI for better numerical stability (as recommended by BoTorch)
        return qLogNoisyExpectedHypervolumeImprovement(
            model=model_list,
            ref_point=ref_point.tolist(),
            X_baseline=normalize(self.X_tensor.to(self.device), bounds=self.X_bounds),
            sampler=sampler,
            prune_baseline=True,
        )

    # def _optimize_acquisition_pr(self, acq_function):
    #     """
    #     THE KEY METHOD: Optimize acquisition using TRUE probabilistic reparameterization.
    #     This implements the actual BOPR algorithm!
    #     """
    #     best_solution = None
    #     best_acq_value = float('-inf')

    #     # Multiple random restarts (as in paper!)
    #     for restart in range(self.n_restarts):
    #         # Initialize theta (continuous parameters for distribution)
    #         theta = torch.nn.Parameter(
    #             torch.randn(self.solution_dim, device=self.device) * 0.1
    #         )
    #         optimizer = torch.optim.Adam([theta], lr=self.lr)

    #         n_steps = 50
    #         local_best_value = float('-inf')
    #         local_best_solution = None

    #         for step in range(n_steps):
    #             optimizer.zero_grad()

    #             # Create distribution from theta
    #             if self.problem_type == 'tsp':
    #                 distribution = PlackettLuceDistribution(theta, self.temperature)
    #             else:  # knapsack
    #                 distribution = BernoulliDistribution(theta, self.temperature)

    #             # Sample M solutions from the distribution (KEY: Multiple MC samples!)
    #             with torch.no_grad():
    #                 samples = distribution.sample(torch.Size([self.mc_samples]))  # [M, n]

    #                 # Encode samples for GP
    #                 if self.problem_type == 'tsp':
    #                     encoded_samples = samples.float() #/ (self.solution_dim - 1)
    #                 else:
    #                     encoded_samples = samples.float()

    #                 # Normalize for GP
    #                 X_candidates = normalize(encoded_samples, bounds=self.X_bounds)  # [M, n]
    #                 X_candidates = X_candidates.unsqueeze(1)  # [M, 1, n] for q=1

    #                 # Evaluate acquisition on all samples
    #                 acq_values = acq_function(X_candidates)  # [M]

    #             # Update baseline (moving average for variance reduction)
    #             current_mean = acq_values.mean().item()
    #             self.baseline = self.baseline * self.baseline_decay + current_mean * (1 - self.baseline_decay)

    #             # Compute advantages (centered rewards)
    #             advantages = acq_values - self.baseline

    #             # Compute log probabilities (THIS NEEDS GRADIENTS!)
    #             log_probs = distribution.log_prob(samples)  # [M]

    #             # REINFORCE gradient: E[(α(x) - b) * ∇log p_θ(x)]
    #             # We use the reparameterization trick: create surrogate loss
    #             surrogate_loss = -(advantages.detach() * log_probs).mean()

    #             # Backprop through log_probs to get gradients w.r.t. theta
    #             surrogate_loss.backward()

    #             # Gradient clipping for stability
    #             torch.nn.utils.clip_grad_norm_([theta], max_norm=1.0)

    #             optimizer.step()

    #             # Track best for this restart
    #             max_acq = acq_values.max().item()
    #             if max_acq > local_best_value:
    #                 local_best_value = max_acq
    #                 best_idx = acq_values.argmax().item()
    #                 if self.problem_type == 'tsp':
    #                     local_best_solution = samples[best_idx].clone()
    #                 else:
    #                     # For knapsack, enforce capacity constraint
    #                     local_best_solution = self._enforce_capacity_constraint(samples[best_idx])

    #         if (restart + 1) % 2 == 0:
    #             print(f"    Restart {restart + 1}/{self.n_restarts}, best acq: {local_best_value:.4f}")

    #         # Update global best across all restarts
    #         if local_best_value > best_acq_value:
    #             best_acq_value = local_best_value
    #             best_solution = local_best_solution

    #     # If no solution found, random fallback
    #     if best_solution is None:
    #         if self.problem_type == 'tsp':
    #             best_solution = torch.randperm(self.solution_dim, device=self.device)
    #         else:
    #             best_solution = torch.tensor(self.problem.random_solution(), device=self.device)

    #     return best_solution

    def _optimize_acquisition_pr(self, acq_function):
        """
        Maximize the acquisition function with probabilistic reparameterization.

        Each of ``n_restarts`` restarts optimizes the parameters ``theta`` of
        a distribution with Adam for 50 steps: Plackett-Luce for TSP/CVRP,
        independent Bernoulli for Knapsack. Gradients use the REINFORCE
        estimator ``E[(alpha(x) - b) * grad log p_theta(x)]`` over
        ``mc_samples`` draws per step, where ``b`` is the moving-average
        baseline ``self.baseline``. The best single sample seen across all
        steps and restarts is returned.

        Args:
            acq_function: Callable taking normalized candidates ``[M, 1, d]``
                and expected to return one value per candidate.

        Returns:
            TSP/CVRP: a 0-based permutation tensor of length n.
            Knapsack: a 0/1 float tensor of length ``n_items`` (capacity is
            not enforced here).
        """
        best_solution = None
        best_acq_value = float('-inf')

        is_permutation = self.problem_type in ('tsp', 'cvrp')
        # Distribution dimension: permutation length (TSP/CVRP) or item count (Knapsack).
        n = getattr(self, 'perm_dim', getattr(self, 'n_cities', self.solution_dim))
        n_steps = 50

        for restart in range(self.n_restarts):
            theta = torch.nn.Parameter(
                torch.randn(n, device=self.device) * 0.1
            )
            optimizer = torch.optim.Adam([theta], lr=self.lr)

            local_best_value = float('-inf')
            local_best_solution = None

            for step in range(n_steps):
                optimizer.zero_grad()

                if is_permutation:
                    distribution = PlackettLuceDistribution(theta, self.temperature)
                else:
                    distribution = BernoulliDistribution(theta, self.temperature)

                with torch.no_grad():
                    samples = distribution.sample(torch.Size([self.mc_samples]))  # [M, n]

                    # Plackett-Luce samples are 0-based, but _encode_solution
                    # treats CVRP customer ids as 1-based (encoding[cid - 1]).
                    # Shift them so candidates share the encoding of the
                    # observed data.
                    to_encode = samples + 1 if self.problem_type == 'cvrp' else samples
                    encoded_samples = torch.stack(
                        [self._encode_solution(s) for s in to_encode]
                    )

                    X_candidates = normalize(
                        encoded_samples.to(self.device),
                        bounds=self.X_bounds
                    ).unsqueeze(1)  # [M, 1, d] for q = 1

                    acq_values = acq_function(X_candidates)

                # Exponential moving-average baseline for variance reduction.
                current_mean = acq_values.mean().item()
                self.baseline = (
                    self.baseline * self.baseline_decay
                    + current_mean * (1 - self.baseline_decay)
                )
                advantages = acq_values - self.baseline

                # Score-function (REINFORCE) surrogate: gradients flow only
                # through log_probs, which depend on theta.
                log_probs = distribution.log_prob(samples)
                surrogate_loss = -(advantages.detach() * log_probs).mean()
                surrogate_loss.backward()

                torch.nn.utils.clip_grad_norm_([theta], max_norm=1.0)
                optimizer.step()

                max_acq = acq_values.max().item()
                if max_acq > local_best_value:
                    local_best_value = max_acq
                    local_best_solution = samples[acq_values.argmax().item()].clone()

            if (restart + 1) % 2 == 0:
                print(f"    Restart {restart + 1}/{self.n_restarts}, best acq: {local_best_value:.4f}")

            if local_best_value > best_acq_value:
                best_acq_value = local_best_value
                best_solution = local_best_solution

        if best_solution is None:
            # Nothing beat -inf: every value was -inf or NaN, or n_restarts == 0.
            # Fall back to a uniformly random solution of the right kind.
            if is_permutation:
                best_solution = torch.randperm(n, device=self.device)
            else:
                best_solution = torch.bernoulli(torch.full((n,), 0.5, device=self.device))

        return best_solution

    def _encode_cvrp_perm(self, perm):
        """Encode 0-indexed permutation to position encoding."""
        encoding = torch.zeros(self.n_customers, dtype=torch.float64)
        for pos, idx in enumerate(perm):
            encoding[idx.item()] = pos / max(1, self.n_customers - 1)
        return encoding
    
    def _encode_cvrp_perm(self, perm):
        """Encode 0-indexed permutation to position encoding."""
        encoding = torch.zeros(self.n_customers, dtype=torch.float64)
        for pos, idx in enumerate(perm):
            encoding[idx.item()] = pos / max(1, self.n_customers - 1)
        return encoding

    def optimize(self) -> Tuple[List, torch.Tensor]:
        """Run the complete BOPR loop and return the final Pareto front.

        The dataset is first seeded with ``_initialize_random_samples``. Each
        of ``self.n_iterations`` iterations then:

        1. refits the GP models;
        2. builds the acquisition function;
        3. proposes one candidate with ``_optimize_acquisition_pr``;
        4. evaluates the candidate on ``self.problem`` and appends it to the
           observed data.

        Objectives are stored internally in maximization orientation (negated
        when ``self.minimization_problem`` is set). Unless
        ``self.fixed_reference_point`` is set, the hypervolume reference point
        is refreshed after every iteration.

        Returns:
            ``(pareto_solutions, pareto_Y)``: the decoded non-dominated
            solutions and their objective values, converted back to the
            problem's raw orientation.
        """
        banner = "=" * 70
        print("\n" + banner)
        print("STARTING PROPER BOPR OPTIMIZATION")
        print(banner)

        self._initialize_random_samples()

        def to_eval_format(candidate):
            # problem.evaluate takes plain Python lists. CVRP candidates are
            # 0-indexed permutations that must become 1-indexed customer IDs
            # split into capacity-respecting routes.
            if isinstance(candidate, torch.Tensor):
                flat = candidate.cpu().tolist()
            else:
                flat = candidate
            if self.problem_type != 'cvrp':
                return flat
            return self._split_into_routes([idx + 1 for idx in flat])

        for it in range(1, self.n_iterations + 1):
            print(f"\n--- Iteration {it}/{self.n_iterations} ---")

            print("Fitting GP models...")
            self._fit_gp_models()

            print("Creating acquisition function...")
            acq_function = self._create_acquisition_function()

            print("Optimizing acquisition with TRUE PR...")
            candidate = self._optimize_acquisition_pr(acq_function)

            raw_obj = torch.tensor(
                self.problem.evaluate(to_eval_format(candidate)),
                dtype=torch.float64,
            )
            print(f"New solution (raw, minimization space): {raw_obj.cpu().numpy()}")

            # Internal convention: larger is better.
            internal_obj = -raw_obj if self.minimization_problem else raw_obj

            self.X_observed.append(self._encode_solution(candidate).cpu())
            self.Y_observed.append(internal_obj.cpu())
            self.X_tensor = torch.stack(self.X_observed)
            self.Y_tensor = torch.stack(self.Y_observed)

            if not self.fixed_reference_point:
                worst = self.Y_tensor.min(dim=0).values
                self.reference_point = tuple((worst - 0.1 * worst.abs()).tolist())

            front_mask = is_non_dominated(self.Y_tensor)
            n_front = front_mask.sum().item()
            hv = self._compute_hypervolume(self.Y_tensor[front_mask])
            print(f"Pareto front: {n_front} solutions, HV: {hv:.4f}")

        front_mask = is_non_dominated(self.Y_tensor)
        front_Y_internal = self.Y_tensor[front_mask]
        front_X = self.X_tensor[front_mask]

        # Undo the internal maximization orientation for the caller.
        front_Y = -front_Y_internal if self.minimization_problem else front_Y_internal

        pareto_solutions = []
        for enc in front_X:
            # TSP encodings are rescaled by (solution_dim - 1) before decoding.
            if self.problem_type == 'tsp':
                enc = enc * (self.solution_dim - 1)
            pareto_solutions.append(self._decode_solution(enc))

        final_hv = self._compute_hypervolume(front_Y_internal)
        print("\n" + banner)
        print("OPTIMIZATION COMPLETE")
        print(f"Final Pareto front: {len(front_Y)} solutions")
        print(f"Final HV: {final_hv:.4f}")
        print(banner)

        return pareto_solutions, front_Y

    def _split_into_routes(self, customer_order):
        """Split customer order into capacity-respecting routes."""
        routes = []
        current_route = []
        current_demand = 0
        
        for cid in customer_order:
            demand = self.customers[cid].demand
            if current_demand + demand <= self.vehicle_capacity:
                current_route.append(cid)
                current_demand += demand
            else:
                if current_route:
                    routes.append(current_route)
                current_route = [cid]
                current_demand = demand
        
        if current_route:
            routes.append(current_route)
        return routes

###############################################################################
# Wrapper for MOCO Evaluator
###############################################################################

class ProperBOPRWrapper:
    """Adapter that runs ProperBOPR behind the MOCO evaluator interface.

    The problem type is inferred from the problem's class name. Algorithm
    settings come from keyword arguments, falling back to defaults. ``run``
    returns the Pareto front as ``[[solution, objectives], ...]``.
    """

    def __init__(self, problem, **kwargs):
        """Configure the wrapper and print a summary of its settings.

        Args:
            problem: Problem instance whose class name contains ``'TSP'``,
                ``'Knapsack'`` or ``'CVRP'`` (e.g. BiObjectiveTSP,
                TriObjectiveTSP, MultiObjectiveKnapsack, or a CVRP class).
            **kwargs: Optional algorithm parameters: ``n_initial``,
                ``n_iterations``, ``mc_samples``, ``lr``, ``temperature``,
                ``n_restarts``, ``reference_point``, ``sparse_gp``,
                ``inducing_points`` and ``device``. Other keys are ignored.

        Raises:
            ValueError: If the problem class name matches none of the markers.
        """
        self.problem = problem

        class_name = problem.__class__.__name__
        # Markers are checked in this order; the first match wins.
        for marker, ptype in (('TSP', 'tsp'), ('Knapsack', 'knapsack'), ('CVRP', 'cvrp')):
            if marker in class_name:
                self.problem_type = ptype
                break
        else:
            raise ValueError(f"Unknown problem class: {class_name}")

        settings = (
            ('n_initial', 50),
            ('n_iterations', 30),
            ('mc_samples', 64),        # 64 is better than 32 for REINFORCE
            ('lr', 0.05),
            ('temperature', 1.0),      # 1.0 is standard for BOPR
            ('n_restarts', 5),
            ('reference_point', None),
            ('sparse_gp', False),
            ('inducing_points', 100),
            ('device', None),
        )
        for attr, default in settings:
            setattr(self, attr, kwargs.get(attr, default))

        rule = '=' * 70
        print(f"\n{rule}")
        print("PROPER BOPR WRAPPER INITIALIZED")
        print(rule)
        print(f"Problem type: {self.problem_type}")
        print(f"Problem class: {class_name}")
        print("Parameters:")
        for attr in ('n_initial', 'n_iterations', 'mc_samples', 'lr',
                     'temperature', 'n_restarts', 'reference_point', 'sparse_gp'):
            print(f"  {attr}: {getattr(self, attr)}")
        if self.sparse_gp:
            print(f"  inducing_points: {self.inducing_points}")
        print(f"  device: {self.device or 'auto'}")
        print(f"{rule}\n")

    def run(self):
        """Run BOPR and return the Pareto front in MOCO format.

        Returns:
            ``[[solution, [obj1, obj2, ...]], ...]``. Tensor solutions and
            objective rows are converted to (nested) Python lists; non-tensor
            solutions are passed through unchanged.
        """
        optimizer = ProperBOPR(
            problem=self.problem,
            n_initial=self.n_initial,
            n_iterations=self.n_iterations,
            mc_samples=self.mc_samples,
            lr=self.lr,
            temperature=self.temperature,
            n_restarts=self.n_restarts,
            reference_point=self.reference_point,
            device=self.device,
            sparse_gp=self.sparse_gp,
            inducing_points=self.inducing_points,
        )
        pareto_solutions, pareto_objectives = optimizer.optimize()

        def as_list(t):
            return t.cpu().numpy().tolist()

        return [
            [as_list(sol) if isinstance(sol, torch.Tensor) else sol, as_list(obj)]
            for sol, obj in zip(pareto_solutions, pareto_objectives)
        ]