# u_rankmoea.py
# U-RankMOEA: A Unified Rank-Based and Uncertainty-Aware Framework
# for High-Dimensional Expensive Multi-Objective Optimization
#
# Implementation based on the paper:
# "U-RANKMOEA: A UNIFIED RANK-BASED AND UNCERTAINTY-AWARE FRAMEWORK FOR 
# HIGH-DIMENSIONAL EXPENSIVE MULTI-OBJECTIVE OPTIMIZATION"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['font.size'] = 12
matplotlib.rcParams['figure.dpi'] = 300
import seaborn as sns
sns.set_style("whitegrid")
from scipy.stats import qmc
from sklearn.preprocessing import StandardScaler
import time
import os
import warnings
warnings.filterwarnings("ignore")

# Configuration for U-RankMOEA
CFG = dict(
    # --- Problem size and evaluation budget ---
    D=30,  # number of decision variables
    M=2,  # number of objectives
    N_init=100,  # size of the initial Latin Hypercube design
    NP=50,  # population size
    maxFEs=300,  # total budget of true (expensive) function evaluations
    batch_size=6,  # true evaluations spent per iteration
    # --- Bayesian rank classifier ---
    clf_hidden=128,
    clf_dropout=0.3,
    clf_epochs=60,
    clf_lr=0.001,
    clf_ensembles=5,
    mc_samples_base=8,
    mc_samples_max=32,
    mc_threshold=0.1,
    # --- Deep GP surrogate ---
    dgp_hidden=128,
    dgp_layers=2,
    dgp_epochs=80,
    dgp_lr=0.001,
    dgp_ensembles=5,
    inducing_points=50,
    # --- Acquisition network ---
    acq_hidden=64,
    acq_epochs=40,
    acq_lr=0.001,
    history_buffer=200,
    # --- Complexity-aware screening ---
    screening_pool=1000,
    expensive_eval=100,
    uncertainty_weight=0.3,
    seed=42,
)

# Problem definitions
def zdt1(x):
    """ZDT1 benchmark problem (two objectives, minimisation).

    Args:
        x: One decision vector (1-D) or a batch of them (2-D, one per row).
            The first column is f1; the remaining columns are the distance
            variables that enter ``g``.

    Returns:
        (N, 2) array whose columns are f1 and f2.
    """
    x = np.atleast_2d(x)
    D = x.shape[1]
    f1 = x[:, 0]
    # g = 1 + 9 * mean(distance variables). With D == 1 there are none, so
    # g is exactly 1 (the original formula evaluated 0 / 0 = nan here).
    if D > 1:
        g = 1 + 9.0 * np.sum(x[:, 1:], axis=1) / (D - 1)
    else:
        g = np.ones(x.shape[0])
    h = 1 - np.sqrt(f1 / g)
    f2 = g * h
    return np.column_stack([f1, f2])

def zdt2(x):
    """ZDT2 benchmark problem (two objectives, minimisation).

    Args:
        x: One decision vector (1-D) or a batch of them (2-D, one per row).
            The first column is f1; the remaining columns are the distance
            variables that enter ``g``.

    Returns:
        (N, 2) array whose columns are f1 and f2.
    """
    x = np.atleast_2d(x)
    D = x.shape[1]
    f1 = x[:, 0]
    # g = 1 + 9 * mean(distance variables). With D == 1 there are none, so
    # g is exactly 1 (the original formula evaluated 0 / 0 = nan here).
    if D > 1:
        g = 1 + 9.0 * np.sum(x[:, 1:], axis=1) / (D - 1)
    else:
        g = np.ones(x.shape[0])
    h = 1 - (f1 / g) ** 2
    f2 = g * h
    return np.column_stack([f1, f2])

def dtlz2(x, M=2):
    """DTLZ2 benchmark problem with ``M`` objectives (minimisation).

    The first ``M - 1`` columns of ``x`` are position variables. The remaining
    columns enter the distance function g = sum((x_k - 0.5)^2).
    Objective i is (1 + g) times cos(x_j * pi/2) for j < M - i - 1, times
    sin(x_{M-i-1} * pi/2) when i > 0.

    Returns:
        (N, M) float64 array of objective values.
    """
    x = np.atleast_2d(x)
    n_rows = x.shape[0]

    angles = x * np.pi / 2
    cosines = np.cos(angles)
    sines = np.sin(angles)
    distance = np.sum((x[:, M - 1:] - 0.5) ** 2, axis=1)

    objectives = np.zeros((n_rows, M))
    for obj in range(M):
        column = objectives[:, obj]
        column[:] = 1 + distance
        # Multiply factors in the same left-to-right order as the formula.
        for pos in range(M - obj - 1):
            column *= cosines[:, pos]
        if obj:
            column *= sines[:, M - obj - 1]

    return objectives

def true_pareto_front_zdt1(n_points=100):
    """Sample the analytic ZDT1 Pareto front f2 = 1 - sqrt(f1), f1 in [0, 1].

    Returns:
        (n_points, 2) array with f1 evenly spaced from 0 to 1.
    """
    front_f1 = np.linspace(0, 1, n_points)
    front_f2 = 1 - np.sqrt(front_f1)
    return np.stack((front_f1, front_f2), axis=1)

# Utility functions
def lhs_samples(n, D, lower=0.0, upper=1.0, seed=None):
    """Draw ``n`` Latin Hypercube samples in ``D`` dimensions.

    Samples are generated in the unit cube and then affinely mapped to
    ``[lower, upper]`` per dimension. Passing ``seed`` makes the design
    reproducible.
    """
    unit_design = qmc.LatinHypercube(d=D, seed=seed).random(n)
    return qmc.scale(unit_design, lower, upper)

def nondominated_sort_fast(objs):
    """Non-dominated sorting of the rows of ``objs`` (minimisation).

    Row p dominates row q when p is no worse than q in every objective and
    strictly better in at least one.

    Returns:
        (frontno, fronts): ``frontno[i]`` is the 1-based front of row i, and
        ``fronts`` lists the row indices of each front in peeling order.
        For empty input returns ``(np.array([]), [])``.
    """
    n_points = objs.shape[0]
    if n_points == 0:
        return np.array([]), []

    # dominance[p, q] is True when row p dominates row q.
    no_worse = np.all(objs[:, None, :] <= objs[None, :, :], axis=2)
    strictly_better = np.any(objs[:, None, :] < objs[None, :, :], axis=2)
    dominance = no_worse & strictly_better

    dominated_by = [np.flatnonzero(row).tolist() for row in dominance]
    # Column sums count how many rows dominate each row.
    n_dominators = dominance.sum(axis=0).astype(int)

    front_of = np.full(n_points, np.inf)
    fronts = []
    layer = np.flatnonzero(n_dominators == 0).tolist()
    rank = 1
    while layer:
        fronts.append(layer)
        front_of[layer] = rank
        next_layer = []
        for p in layer:
            for q in dominated_by[p]:
                n_dominators[q] -= 1
                if n_dominators[q] == 0:
                    next_layer.append(q)
        layer = next_layer
        rank += 1

    return front_of.astype(int), fronts

def nondominated_frontpoints(objs):
    """Return the rows of ``objs`` on the first non-dominated front.

    Rows keep their original order. Empty input yields ``np.array([])``.
    """
    if not len(objs):
        return np.array([])
    on_first_front = nondominated_sort_fast(objs)[0] == 1
    return objs[on_first_front]

def _hv_2d(points, ref_point):
    """Exact 2-D hypervolume via one sweep over points sorted by f1."""
    # Sort by f1, breaking ties by f2 so the lower point is visited first.
    order = np.lexsort((points[:, 1], points[:, 0]))
    hv = 0.0
    lowest_f2 = ref_point[1]
    for f1, f2 in points[order]:
        # Each improving point adds a horizontal strip from f2 up to the
        # lowest f2 seen so far, spanning f1 .. ref_point[0]. Points that
        # do not lower f2 are dominated and add nothing.
        if f2 < lowest_f2:
            hv += (ref_point[0] - f1) * (lowest_f2 - f2)
            lowest_f2 = f2
    return hv


def _hv_slice(points, ref_point):
    """Exact hypervolume by slicing along the last objective (recursive)."""
    if len(points) == 0:
        return 0.0
    n_obj = points.shape[1]
    if n_obj == 1:
        return ref_point[0] - points[:, 0].min()
    if n_obj == 2:
        return _hv_2d(points, ref_point)
    points = points[np.argsort(points[:, -1], kind="stable")]
    levels = np.append(points[:, -1], ref_point[-1])
    hv = 0.0
    for i in range(len(points)):
        depth = levels[i + 1] - levels[i]
        if depth > 0:
            # Within this slab, the region dominated in the remaining
            # objectives is the union over every point at or below level i.
            hv += depth * _hv_slice(points[: i + 1, :-1], ref_point[:-1])
    return hv


def calculate_hypervolume_simple(points, ref_point=None):
    """Exact hypervolume (minimisation) of ``points`` w.r.t. ``ref_point``.

    The 2-D case is a single O(n log n) sweep. Other dimensionalities use
    recursive slicing, which is exact but grows quickly with the number of
    objectives.

    Args:
        points: (N, M) objective vectors, or a single vector. Dominated points
            and points not strictly better than the reference point in every
            objective are allowed; they contribute nothing.
        ref_point: Length-M reference point. Defaults to (1.1, 1.1) when
            M == 2 (the usual ZDT choice), otherwise ``max(points) + 1``
            per objective.

    Returns:
        Hypervolume as a float; 0.0 for empty input.
    """
    if len(points) == 0:
        return 0.0
    points = np.atleast_2d(np.asarray(points, dtype=float))
    n_obj = points.shape[1]

    if ref_point is None:
        if n_obj == 2:
            ref_point = np.array([1.1, 1.1])  # For ZDT problems
        else:
            ref_point = np.max(points, axis=0) + 1
    ref_point = np.asarray(ref_point, dtype=float)

    # Only points strictly inside the reference box enclose any volume.
    inside = points[np.all(points < ref_point, axis=1)]
    return float(_hv_slice(inside, ref_point))

def igd_metric(obtained_front, true_front):
    """Inverted Generational Distance of ``obtained_front`` to ``true_front``.

    For every reference point, take the Euclidean distance to its nearest
    obtained point, then average those distances. Returns ``np.inf`` if
    either front is empty.
    """
    if len(obtained_front) == 0 or len(true_front) == 0:
        return np.inf

    nearest = [
        np.min(np.sqrt(np.sum((obtained_front - reference) ** 2, axis=1)))
        for reference in true_front
    ]
    return np.mean(nearest)

# Neural Network Components

class BayesianRankClassifier(nn.Module):
    """Bayesian rank-based classifier with complexity controls (Section 3.3).

    An MLP that maps a (batch, D) input to ``n_classes`` rank-class scores.
    It has a learnable temperature for calibration and supports MC-Dropout
    sampling.

    Args:
        D: Input dimensionality.
        hidden: Width of the first two hidden layers; the third is hidden // 2.
        n_classes: Number of rank classes.
        dropout: Dropout probability after each hidden layer.
    """

    # Bounds applied to the temperature before dividing the logits, so a
    # degenerate learned/calibrated temperature cannot blow up or flatten them.
    _T_MIN = 0.1
    _T_MAX = 10.0

    def __init__(self, D, hidden=128, n_classes=4, dropout=0.3):
        super().__init__()
        self.dropout_p = dropout

        # Network architecture as specified in paper
        self.net = nn.Sequential(
            nn.Linear(D, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden // 2, n_classes)
        )

        # Temperature parameter for calibration
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x, mc_samples=1):
        """Score ``x`` (shape (batch, D)).

        Args:
            x: Input batch.
            mc_samples: 1 for a single deterministic-mode pass; >1 for
                MC-Dropout sampling.

        Returns:
            If ``mc_samples == 1``: temperature-scaled logits of shape
            (batch, n_classes). These are computed in the module's current
            train/eval mode, with gradients.
            Otherwise: softmax probabilities of shape
            (mc_samples, batch, n_classes). These are computed with dropout
            active and without gradient tracking. The module's previous
            train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``mc_samples < 1``.
        """
        if mc_samples < 1:
            raise ValueError(f"mc_samples must be >= 1, got {mc_samples}")

        temperature = torch.clamp(self.temperature, min=self._T_MIN, max=self._T_MAX)
        if mc_samples == 1:
            return self.net(x) / temperature

        # MC-Dropout sampling: dropout must be active, but the caller's mode
        # (e.g. eval during prediction, train during a training loop) is
        # restored afterwards rather than being forced to eval.
        was_training = self.training
        self.train()
        try:
            with torch.no_grad():
                outputs = [
                    F.softmax(self.net(x) / temperature, dim=1)
                    for _ in range(mc_samples)
                ]
        finally:
            self.train(was_training)
        return torch.stack(outputs, dim=0)

class DeepGPSurrogate(nn.Module):
    """Deep Gaussian Process surrogate with neural mean functions (Section 3.4).

    The forward pass is a stack of MLP feature blocks followed by two heads.
    One head predicts a mean; the other predicts a variance (Softplus + 1e-6).
    The model returns ``(mean, variance)`` per input row.

    Args:
        D: Input dimensionality.
        hidden: Width of every feature block and input width of the heads.
        n_layers: Number of feature blocks.
        n_inducing: Number of rows in the ``inducing_points`` parameter.
    """

    def __init__(self, D, hidden=128, n_layers=2, n_inducing=50):
        super().__init__()
        self.n_layers = n_layers
        self.n_inducing = n_inducing

        # Neural feature extractors for each layer. They are built in layer
        # order, so parameter initialisation consumes the RNG in a fixed order.
        self.feature_nets = nn.ModuleList(
            [self._feature_block(hidden if depth else D, hidden) for depth in range(n_layers)]
        )

        # Mean and variance heads (the variance head ends in Softplus).
        self.mean_head = self._head(hidden)
        self.var_head = self._head(hidden, positive=True)

        # Inducing point locations (learnable).
        # NOTE(review): forward() never reads this parameter, so it gets no
        # gradient; it is only part of the state dict.
        self.inducing_points = nn.Parameter(torch.randn(n_inducing, D))

    @staticmethod
    def _feature_block(in_dim, hidden):
        """One feature block: Linear-LayerNorm-ReLU-Dropout(0.2)-Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, hidden),
        )

    @staticmethod
    def _head(hidden, positive=False):
        """Two-layer scalar head; appends Softplus when ``positive``."""
        layers = [nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Linear(hidden // 2, 1)]
        if positive:
            layers.append(nn.Softplus())
        return nn.Sequential(*layers)

    def forward(self, x):
        features = x
        for block in self.feature_nets:
            features = block(features)

        # Predictive mean and (strictly positive) variance, one value per row.
        mean = self.mean_head(features).squeeze(-1)
        variance = self.var_head(features).squeeze(-1) + 1e-6
        return mean, variance

class HistoryAwareAcquisition(nn.Module):
    """History-aware acquisition network (Section 3.5).

    A shared two-layer trunk feeds two sigmoid heads. The outputs are
    ``(hv_score, div_score)``, each in (0, 1), one value per input row.

    Args:
        feat_dim: Width of the input feature vector.
        hidden: Trunk width; heads use hidden // 2 internally.
    """

    def __init__(self, feat_dim, hidden=64):
        super().__init__()

        # Shared trunk, kept as one flat Sequential of two
        # Linear-LayerNorm-ReLU-Dropout(0.2) stages.
        self.net = nn.Sequential(
            *self._trunk_stage(feat_dim, hidden),
            *self._trunk_stage(hidden, hidden),
        )

        # Hypervolume improvement predictor
        self.hv_head = self._score_head(hidden)

        # Diversity score predictor
        self.div_head = self._score_head(hidden)

    @staticmethod
    def _trunk_stage(in_dim, hidden):
        """Layers of one trunk stage, returned as a list for splicing."""
        return [nn.Linear(in_dim, hidden), nn.LayerNorm(hidden), nn.ReLU(), nn.Dropout(0.2)]

    @staticmethod
    def _score_head(hidden):
        """Scalar head squashed to (0, 1) by a sigmoid."""
        return nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        shared = self.net(x)
        return self.hv_head(shared).squeeze(-1), self.div_head(shared).squeeze(-1)

# Main U-RankMOEA Class
class URankMOEA:
    """U-RankMOEA: Unified Rank-Based and Uncertainty-Aware Framework.

    A surrogate-assisted optimisation loop. An initial LHS design is truly
    evaluated. Each iteration then does the following:
      1. trains an ensemble of rank classifiers;
      2. trains one ensemble of mean/variance surrogates per objective;
      3. optionally trains an acquisition network;
      4. screens a large pool of DE-style offspring with the classifier;
      5. spends true evaluations on a small selected batch.

    Args:
        cfg: Configuration dict (see ``CFG`` for the keys read here).
        problem_func: Expensive objective mapping an (n, D) array to (n, M).
        true_front: Optional reference Pareto front used for IGD and plots.
            Defaults to the analytic ZDT1 front, which is only meaningful when
            ``problem_func`` is ``zdt1``.
    """

    def __init__(self, cfg, problem_func=zdt1, true_front=None):
        self.cfg = cfg
        self.D = cfg["D"]
        self.M = cfg["M"]
        self.N_init = cfg["N_init"]
        self.NP = cfg["NP"]
        self.maxFEs = cfg["maxFEs"]
        self.batch_size = cfg["batch_size"]
        self.problem_func = problem_func

        # Reference front for IGD. Pass ``true_front`` for any problem other
        # than ZDT1; otherwise IGD is measured against the wrong front.
        if true_front is None:
            self.true_front = true_pareto_front_zdt1()
        else:
            self.true_front = np.atleast_2d(np.asarray(true_front, dtype=float))

        # Set random seeds (global NumPy / torch RNG state)
        np.random.seed(cfg["seed"])
        torch.manual_seed(cfg["seed"])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(cfg["seed"])

        # Archive of every truly evaluated point
        self.archive_X = None
        self.archive_Y = None
        self.FEs = 0

        # Scalers (fitted once, on the initial design)
        self.x_scaler = StandardScaler()
        self.y_scalers = [StandardScaler() for _ in range(self.M)]

        # Model components
        self.clf_ensemble = []
        self.dgp_ensemble = []  # one list of surrogate members per objective
        self.acq_net = None

        # History tracking
        self.hv_history = []
        self.igd_history = []
        self.history_buffer = []  # (x, HV improvement of the batch x was in)

        # Device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"U-RankMOEA initialized on device: {self.device}")

    def _current_igd(self):
        """Return (IGD, first front) of the archive w.r.t. ``self.true_front``."""
        current_front = nondominated_frontpoints(self.archive_Y)
        return igd_metric(current_front, self.true_front), current_front

    def initial_sampling(self):
        """Evaluate an initial Latin Hypercube design and fit the scalers."""
        X0 = lhs_samples(self.N_init, self.D, seed=self.cfg["seed"])
        Y0 = self.problem_func(X0)

        self.archive_X = X0.copy()
        self.archive_Y = Y0.copy()
        self.FEs = self.N_init

        # Fit scalers
        self.x_scaler.fit(self.archive_X)
        for m in range(self.M):
            self.y_scalers[m].fit(self.archive_Y[:, m:m+1])

        # Calculate initial metrics
        hv = calculate_hypervolume_simple(self.archive_Y)
        igd, _ = self._current_igd()

        self.hv_history.append(hv)
        self.igd_history.append(igd)

        print(f"Initial sampling: {self.N_init} points, HV={hv:.4f}, IGD={igd:.4f}")

    def train_classifier_ensemble(self, X, Y):
        """Train the Bayesian rank-classifier ensemble on (X, Y).

        Labels are non-dominated front numbers capped at 4 and shifted to
        classes 0-3, so class 0 is the first front. Each member is trained
        full-batch with Adam and then temperature-calibrated by a grid search
        (on the training data itself). With fewer than 10 points nothing is
        trained and any previous ensemble is kept.
        """
        if len(X) < 10:
            return

        Xs = self.x_scaler.transform(X)
        Xt = torch.tensor(Xs, dtype=torch.float32).to(self.device)

        # Generate rank labels
        frontno, _ = nondominated_sort_fast(Y)
        ranks = np.minimum(frontno, 4).astype(int) - 1  # Ranks 0-3
        yt = torch.tensor(ranks, dtype=torch.long).to(self.device)

        criterion = nn.CrossEntropyLoss()
        temperatures = torch.linspace(0.1, 3.0, 30).to(self.device)

        self.clf_ensemble = []
        for _ in range(self.cfg["clf_ensembles"]):
            net = BayesianRankClassifier(
                self.D,
                hidden=self.cfg["clf_hidden"],
                dropout=self.cfg["clf_dropout"]
            ).to(self.device)

            optimizer = optim.Adam(net.parameters(), lr=self.cfg["clf_lr"])

            net.train()
            for _ in range(self.cfg["clf_epochs"]):
                optimizer.zero_grad()
                loss = criterion(net(Xt), yt)
                loss.backward()
                optimizer.step()

            # Temperature scaling. Search on the RAW logits (net.net); net(Xt)
            # is already divided by the learned temperature, and searching on
            # it would double-scale.
            net.eval()
            with torch.no_grad():
                raw_logits = net.net(Xt)
                nlls = torch.stack([criterion(raw_logits / t, yt) for t in temperatures])
                net.temperature.data = temperatures[torch.argmin(nlls)].reshape(1)

            self.clf_ensemble.append(net)

    def train_dgp_ensemble(self, X, Y):
        """Train one ensemble of mean/variance surrogates per objective.

        Targets are standardised with ``self.y_scalers``. The loss is Gaussian
        NLL plus 0.1 * MSE. With fewer than 10 points nothing is trained.
        """
        if len(X) < 10:
            return

        Xs = self.x_scaler.transform(X)
        Xt = torch.tensor(Xs, dtype=torch.float32).to(self.device)

        self.dgp_ensemble = []
        n_ens = self.cfg["dgp_ensembles"]

        for m in range(self.M):  # One ensemble per objective
            y_scaled = self.y_scalers[m].transform(Y[:, m].reshape(-1, 1)).flatten()
            yt = torch.tensor(y_scaled, dtype=torch.float32).to(self.device)

            members = []
            for _ in range(n_ens):
                net = DeepGPSurrogate(
                    self.D,
                    hidden=self.cfg["dgp_hidden"],
                    n_layers=self.cfg["dgp_layers"],
                    n_inducing=self.cfg["inducing_points"]
                ).to(self.device)

                optimizer = optim.Adam(net.parameters(), lr=self.cfg["dgp_lr"])

                net.train()
                for _ in range(self.cfg["dgp_epochs"]):
                    optimizer.zero_grad()

                    mean_pred, var_pred = net(Xt)

                    # Negative log-likelihood loss
                    dist = Normal(mean_pred, torch.sqrt(var_pred))
                    nll_loss = -dist.log_prob(yt).mean()

                    # Add MSE for stability
                    mse_loss = F.mse_loss(mean_pred, yt)
                    total_loss = nll_loss + 0.1 * mse_loss

                    total_loss.backward()
                    optimizer.step()

                members.append(net)

            self.dgp_ensemble.append(members)

    def predict_classifier_with_uncertainty(self, X_cands):
        """Ensemble-averaged rank-class probabilities plus an uncertainty score.

        Returns:
            (mean_probs, uncertainty). ``mean_probs`` has shape (N, 4).
            ``uncertainty`` has shape (N,) and is the entropy of the mean
            prediction plus the mean across-member variance of the
            probabilities. Before any ensemble exists, returns uniform
            probabilities and a constant 0.5.

        NOTE(review): MC-Dropout is not used here. The previous implementation
        ran it but discarded the result, and cfg["mc_samples_*"] are currently
        unused.
        """
        if len(self.clf_ensemble) == 0:
            n = len(X_cands)
            return np.full((n, 4), 0.25), np.full(n, 0.5)

        Xs = self.x_scaler.transform(X_cands)
        Xt = torch.tensor(Xs, dtype=torch.float32).to(self.device)

        member_probs = []
        with torch.no_grad():
            for net in self.clf_ensemble:
                net.eval()
                member_probs.append(F.softmax(net(Xt), dim=1).cpu().numpy())
        member_probs = np.stack(member_probs, axis=0)  # (members, N, classes)

        mean_probs = np.mean(member_probs, axis=0)

        eps = 1e-12
        # Total predictive entropy of the averaged distribution ...
        predictive_entropy = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
        # ... plus ensemble disagreement (variance across members).
        disagreement = np.var(member_probs, axis=0).mean(axis=1)

        return mean_probs, predictive_entropy + disagreement

    def predict_dgp_with_uncertainty(self, X_cands):
        """Surrogate prediction with an ensemble uncertainty decomposition.

        Returns:
            (mean, epistemic, aleatoric), each of shape (N, M) and in the
            original objective units:
            - mean: ensemble-mean prediction;
            - epistemic: variance of the members' means;
            - aleatoric: mean of the members' predicted variances.
        """
        if len(self.dgp_ensemble) == 0 or len(self.dgp_ensemble[0]) == 0:
            n = len(X_cands)
            # No surrogate yet. Never fall back to the true (expensive)
            # objective here: that would spend evaluations outside the FE
            # budget. Use the archive mean (or zeros) with flat uncertainty.
            if self.archive_Y is not None and len(self.archive_Y) > 0:
                mean_pred = np.tile(np.mean(self.archive_Y, axis=0), (n, 1))
            else:
                mean_pred = np.zeros((n, self.M))
            uncertainty = np.full((n, self.M), 0.1)
            return mean_pred, uncertainty, uncertainty.copy()

        Xs = self.x_scaler.transform(X_cands)
        Xt = torch.tensor(Xs, dtype=torch.float32).to(self.device)

        N = len(X_cands)
        M = self.M
        ens = len(self.dgp_ensemble[0])

        all_means = np.zeros((ens, N, M))
        all_vars = np.zeros((ens, N, M))

        with torch.no_grad():
            for m in range(M):
                # Variances scale with the square of the standardisation scale.
                scale_sq = float(self.y_scalers[m].scale_[0]) ** 2
                for e, net in enumerate(self.dgp_ensemble[m]):
                    net.eval()
                    mean_pred, var_pred = net(Xt)

                    # Transform back to original scale
                    all_means[e, :, m] = self.y_scalers[m].inverse_transform(
                        mean_pred.cpu().numpy().reshape(-1, 1)
                    ).flatten()
                    all_vars[e, :, m] = var_pred.cpu().numpy() * scale_sq

        # Uncertainty decomposition
        mean_prediction = all_means.mean(axis=0)
        epistemic_uncertainty = all_means.var(axis=0)
        aleatoric_uncertainty = all_vars.mean(axis=0)

        return mean_prediction, epistemic_uncertainty, aleatoric_uncertainty

    def build_acquisition_features(self, X_cands, clf_probs, clf_uncertainty,
                                 mean_pred, epistemic_unc, aleatoric_unc):
        """Build the (N, 3M + 6) acquisition feature matrix (Eq. 16).

        The column blocks are, in order:
          - surrogate mean (M);
          - epistemic uncertainty (M);
          - aleatoric uncertainty (M);
          - first 3 class probabilities;
          - classifier uncertainty (1);
          - mean and std of the last <=9 HV changes (1 + 1), the same for
            every row.
        """
        # Historical hypervolume statistics
        if len(self.hv_history) > 1:
            recent_hv_changes = np.diff(self.hv_history[-10:])
            mu_hv = np.mean(recent_hv_changes)
            sigma_hv = np.std(recent_hv_changes)
        else:
            mu_hv = sigma_hv = 0.0

        n_cands = len(X_cands)
        features = [
            mean_pred,
            epistemic_unc,
            aleatoric_unc,
            clf_probs[:, :3],  # K-1 = 3 class probabilities
            clf_uncertainty.reshape(-1, 1),
            np.full((n_cands, 1), mu_hv),
            np.full((n_cands, 1), sigma_hv),
        ]
        return np.hstack(features)

    def train_acquisition_network(self, X_pool, hv_improvements):
        """Fit the acquisition network's HV head to recorded HV improvements.

        Targets are min-max normalised to [0, 1] to match the sigmoid head.
        Only ``hv_head`` receives a loss. NOTE(review): ``div_head`` is never
        trained, yet its output is used in batch selection.
        """
        if len(X_pool) < 10 or len(hv_improvements) < 10:
            return

        # Build features
        clf_probs, clf_unc = self.predict_classifier_with_uncertainty(X_pool)
        mean_pred, epistemic_unc, aleatoric_unc = self.predict_dgp_with_uncertainty(X_pool)

        feat = self.build_acquisition_features(
            X_pool, clf_probs, clf_unc, mean_pred, epistemic_unc, aleatoric_unc
        )

        if self.acq_net is None:
            self.acq_net = HistoryAwareAcquisition(
                feat.shape[1], hidden=self.cfg["acq_hidden"]
            ).to(self.device)

        # Normalize targets into [0, 1]. Constant targets carry no ranking
        # signal; map them to 0, consistent with min-max sending the minimum
        # to 0, instead of leaving raw values the sigmoid head may not fit.
        targets = np.asarray(hv_improvements, dtype=float)
        spread = targets.max() - targets.min()
        if spread > 0:
            targets = (targets - targets.min()) / (spread + 1e-12)
        else:
            targets = np.zeros_like(targets)

        feat_t = torch.tensor(feat, dtype=torch.float32).to(self.device)
        target_t = torch.tensor(targets, dtype=torch.float32).to(self.device)

        optimizer = optim.Adam(self.acq_net.parameters(), lr=self.cfg["acq_lr"])
        criterion = nn.MSELoss()

        self.acq_net.train()
        for _ in range(self.cfg["acq_epochs"]):
            optimizer.zero_grad()
            hv_pred, _ = self.acq_net(feat_t)
            loss = criterion(hv_pred, target_t)
            loss.backward()
            optimizer.step()

    def rank_conditioned_offspring(self, pool_size):
        """Generate rank-conditioned offspring (Eq. 8 in paper).

        Parents are three DISTINCT archive points. They come from the first
        front when it has at least 3 members, otherwise from the whole
        archive.
        - Parent xa on the first front: DE/rand/1 mutant xa + 0.8 * (xb - xc).
        - Otherwise: pull xa halfway towards a random first-front point.
        Results are clipped to [0, 1].
        """
        if len(self.archive_X) < 3:
            return lhs_samples(pool_size, self.D)

        frontno, fronts = nondominated_sort_fast(self.archive_Y)
        best_front = fronts[0]
        diff_scale = 0.8  # DE scale factor for first-front parents
        pull_scale = 0.5  # reduced step for lower-rank parents

        offspring = np.empty((pool_size, self.archive_X.shape[1]))
        for k in range(pool_size):
            # Distinct parents avoid zero difference vectors (xb == xc), which
            # would just copy an already evaluated point.
            if len(best_front) >= 3:
                parent_indices = np.random.choice(best_front, size=3, replace=False)
            else:
                parent_indices = np.random.choice(len(self.archive_X), size=3, replace=False)

            xa, xb, xc = self.archive_X[parent_indices]

            if frontno[parent_indices[0]] == 1:
                v = xa + diff_scale * (xb - xc)
            else:
                xbest = self.archive_X[best_front[np.random.choice(len(best_front))]]
                v = xa + pull_scale * (xbest - xa)

            offspring[k] = np.clip(v, 0, 1)

        return offspring

    def two_stage_screening(self, iteration):
        """Generate a large offspring pool and keep the classifier's top picks.

        The proxy score is P(first front) + uncertainty_weight * uncertainty.
        The ``cfg["expensive_eval"]`` highest-scoring candidates are returned.
        ``iteration`` is currently unused.
        """
        # Stage 1: Generate large pool with cheap operators
        large_pool = self.rank_conditioned_offspring(self.cfg["screening_pool"])

        # Stage 2: Cheap proxy scoring using classifier
        clf_probs, clf_unc = self.predict_classifier_with_uncertainty(large_pool)
        proxy_scores = clf_probs[:, 0] + self.cfg["uncertainty_weight"] * clf_unc

        top_indices = np.argsort(proxy_scores)[-self.cfg["expensive_eval"]:]
        return large_pool[top_indices]

    def select_batch_for_evaluation(self, candidates, n_select=None):
        """Select up to ``n_select`` candidates (default ``batch_size``).

        Uses the acquisition network when available (hv_score +
        uncertainty_weight * div_score). Otherwise it falls back to the
        classifier uncertainty. Returns an empty selection when
        ``n_select <= 0`` or there are no candidates.
        """
        k = self.batch_size if n_select is None else int(n_select)
        # Guard: argsort(...)[-0:] would select EVERY candidate.
        if k <= 0 or len(candidates) == 0:
            return candidates[:0]

        if self.acq_net is None:
            # Fallback: select based on classifier uncertainty
            _, clf_unc = self.predict_classifier_with_uncertainty(candidates)
            return candidates[np.argsort(clf_unc)[-k:]]

        clf_probs, clf_unc = self.predict_classifier_with_uncertainty(candidates)
        mean_pred, epistemic_unc, aleatoric_unc = self.predict_dgp_with_uncertainty(candidates)

        feat = self.build_acquisition_features(
            candidates, clf_probs, clf_unc, mean_pred, epistemic_unc, aleatoric_unc
        )
        feat_t = torch.tensor(feat, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.acq_net.eval()
            hv_scores, div_scores = self.acq_net(feat_t)
            total_scores = (hv_scores + self.cfg["uncertainty_weight"] * div_scores).cpu().numpy()

        return candidates[np.argsort(total_scores)[-k:]]

    def run(self):
        """Run the main optimisation loop until ``maxFEs`` evaluations are used.

        Returns:
            (archive_X, archive_Y) with every truly evaluated point.
        """
        print("=== Starting U-RankMOEA ===")

        self.initial_sampling()

        iteration = 0
        while self.FEs < self.maxFEs:
            iteration += 1
            print(f"\n=== Iteration {iteration} | FEs: {self.FEs}/{self.maxFEs} ===")

            prev_hv = self.hv_history[-1] if self.hv_history else 0

            # 1. Train classifier ensemble
            self.train_classifier_ensemble(self.archive_X, self.archive_Y)

            # 2. Train Deep GP ensemble
            self.train_dgp_ensemble(self.archive_X, self.archive_Y)

            # 3. Train acquisition network (if enough history)
            if len(self.history_buffer) > 20:
                pool_x = np.array([x for x, _ in self.history_buffer])
                hv_improvements = np.array([imp for _, imp in self.history_buffer])
                self.train_acquisition_network(pool_x, hv_improvements)

            # 4. Two-stage candidate screening
            candidates = self.two_stage_screening(iteration)

            # 5. Select the batch, never exceeding the remaining FE budget
            #    (the last batch may be partial).
            remaining = self.maxFEs - self.FEs
            selected_X = self.select_batch_for_evaluation(
                candidates, n_select=min(self.batch_size, remaining)
            )
            if len(selected_X) == 0:
                # Without an evaluation the budget never shrinks, so the loop
                # would spin forever.
                print("No candidates selected; stopping early.")
                break

            # 6. Evaluate selected candidates and update archive
            selected_Y = self.problem_func(selected_X)
            self.archive_X = np.vstack([self.archive_X, selected_X])
            self.archive_Y = np.vstack([self.archive_Y, selected_Y])
            self.FEs += len(selected_X)

            # Update history
            current_hv = calculate_hypervolume_simple(self.archive_Y)
            hv_improvement = current_hv - prev_hv
            self.hv_history.append(current_hv)

            # Every point in the batch is credited with the batch's HV change.
            for x in selected_X:
                self.history_buffer.append((x, hv_improvement))

            # Keep buffer size manageable
            if len(self.history_buffer) > self.cfg["history_buffer"]:
                self.history_buffer = self.history_buffer[-self.cfg["history_buffer"]:]

            igd, current_front = self._current_igd()
            self.igd_history.append(igd)

            print(f"Selected {len(selected_X)} candidates")
            print(f"HV: {current_hv:.4f} (+{hv_improvement:.4e})")
            print(f"IGD: {igd:.4f}")
            print(f"Front size: {len(current_front)}")

        print(f"\nOptimization completed. Final HV: {self.hv_history[-1]:.4f}")
        return self.archive_X, self.archive_Y

    def plot_results(self, save_path='urankmoea_results.png', show=True):
        """Plot HV/IGD convergence, the final front and the archive by front.

        Args:
            save_path: File the figure is saved to (None to skip saving).
            show: Whether to call ``plt.show()``.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Hypervolume convergence
        ax1.plot(self.hv_history, 'b-', linewidth=2)
        ax1.fill_between(range(len(self.hv_history)), self.hv_history, alpha=0.3)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Hypervolume')
        ax1.set_title('Hypervolume Convergence')
        ax1.grid(True, alpha=0.7)

        # IGD convergence
        ax2.semilogy(self.igd_history, 'r-', linewidth=2)
        ax2.fill_between(range(len(self.igd_history)), self.igd_history, alpha=0.3)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('IGD (log scale)')
        ax2.set_title('IGD Convergence')
        ax2.grid(True, alpha=0.7)

        if self.M == 2:
            # Pareto front comparison
            ax3.plot(self.true_front[:, 0], self.true_front[:, 1], 'k-', linewidth=3,
                    label='True Pareto Front', alpha=0.8)

            final_front = nondominated_frontpoints(self.archive_Y)
            if len(final_front) > 0:
                sorted_front = final_front[np.argsort(final_front[:, 0])]
                ax3.scatter(sorted_front[:, 0], sorted_front[:, 1],
                           c='red', s=80, alpha=0.8, label=f'U-RankMOEA (n={len(final_front)})',
                           edgecolors='darkred', linewidth=2)

            ax3.set_xlabel('Objective 1')
            ax3.set_ylabel('Objective 2')
            ax3.set_title('Pareto Front Comparison')
            ax3.legend()
            ax3.grid(True, alpha=0.7)

            # Objective space with all points, coloured by front (first 5)
            frontno, fronts = nondominated_sort_fast(self.archive_Y)
            colors = ['red', 'orange', 'yellow', 'green', 'blue']
            for f in range(min(5, len(fronts))):
                mask = frontno == f + 1
                if np.any(mask):
                    alpha = 0.8 if f == 0 else 0.4
                    size = 60 if f == 0 else 20
                    ax4.scatter(self.archive_Y[mask, 0], self.archive_Y[mask, 1],
                               c=colors[f], alpha=alpha, s=size,
                               label=f'Front {f+1}' if f < 3 else '')

            ax4.set_xlabel('Objective 1')
            ax4.set_ylabel('Objective 2')
            ax4.set_title('Archive Evolution')
            ax4.legend()
            ax4.grid(True, alpha=0.7)

        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        if show:
            plt.show()

# Example usage and testing
def _run_demo():
    """Run U-RankMOEA on a 30-D ZDT1 instance, print a summary and plot it."""
    print("=== U-RankMOEA Implementation Test ===")

    # Demo configuration: the defaults with a smaller initial design and batch.
    demo_cfg = {**CFG, "D": 30, "maxFEs": 300, "N_init": 60, "batch_size": 4}

    moea = URankMOEA(demo_cfg, problem_func=zdt1)
    _, archive_Y = moea.run()

    if len(archive_Y):
        hv = calculate_hypervolume_simple(archive_Y)
        reference_front = true_pareto_front_zdt1()
        front = nondominated_frontpoints(archive_Y)
        igd = igd_metric(front, reference_front)

        print("\n=== FINAL RESULTS ===")
        print(f"Function Evaluations: {moea.FEs}")
        print(f"Final Hypervolume: {hv:.6f}")
        print(f"Final IGD: {igd:.6f}")
        print(f"Final Front Size: {len(front)}")

    moea.plot_results()

    print("\n✅ U-RankMOEA execution completed successfully!")


if __name__ == "__main__":
    _run_demo()