"""
Select-Best Generation Using Uncertainty Quantification.

This module implements the "select-best" experiment where UQ is used at inference
time to select the highest-quality candidate among multiple generations.

For each target y*, we generate K candidates and compare:
- Random selection: pick one uniformly at random
- UQ-guided selection: pick the candidate with lowest uncertainty
- Oracle selection: pick the candidate with lowest ground-truth error

Key uncertainty metrics (Diag-CFM specific):
- Zero-Deviation: ||output[:, :L]||² after synthesis pass (should be ~0)
- Self-Consistency: ||y_reconstructed - y*||² after round-trip through model
- FM Loss: Flow matching loss at t=0.5 interpolation point

Also evaluated (not Diag-CFM specific):
- Ensemble Variance: variance of predicted labels across the ensemble models

Usage:
    # Run on all datasets
    python -m uq_diagcfm.select_best

    # Run on specific dataset
    python -m uq_diagcfm.select_best --dataset gas_turbine
    python -m uq_diagcfm.select_best --dataset unifoil
    python -m uq_diagcfm.select_best --dataset dtlz --P 50

    # Run on all DTLZ dimensions
    python -m uq_diagcfm.select_best --dataset dtlz --P all
"""

import argparse
import gc
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from uq_diagcfm.solvers import euler_method, euler_method_with_conditioning
from uq_diagcfm.uq_evaluation_utils import compute_fm_losses
from uq_diagcfm.utils import get_device

# =============================================================================
# Default Constants
# =============================================================================

# Number of target samples evaluated per dataset (same value for every dataset).
DEFAULT_N_SAMPLES = 1000
# Number of candidates generated per target (same value for every dataset).
DEFAULT_K_CANDIDATES = 10
# Seed of the NumPy generator that subsamples the evaluation data.
DEFAULT_RANDOM_SEED = 42

# Mutable runtime settings read by the evaluation and plotting functions in
# this module; intended to be overridden from the CLI.
_config = dict(
    n_samples=DEFAULT_N_SAMPLES,
    k_candidates=DEFAULT_K_CANDIDATES,
    random_seed=DEFAULT_RANDOM_SEED,
)


# =============================================================================
# Core Candidate Generation
# =============================================================================


def generate_k_candidates(
    model: torch.nn.Module,
    target_labels: torch.Tensor,
    num_design_params: int,
    num_labels: int,
    k: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
    noise_distribution: str = "uniform",
) -> torch.Tensor:
    """Generate K candidate designs for each target label.

    Every candidate comes from one synthesis pass (t=1 -> t=0) started from
    the state ``[target_labels; noise]`` with freshly drawn noise. The passes
    run under ``torch.no_grad()`` (matching the ensemble-variance computation
    in this module), so no autograd graph is retained across the Euler steps.

    Args:
        model: Flow matching model.
        target_labels: Target labels (N, L).
        num_design_params: Design space dimension P.
        num_labels: Label space dimension L.
        k: Number of candidates to generate per target. Must be at least 1,
            because ``torch.stack`` rejects an empty list.
        conditioning: Optional conditioning (N, C).
        steps: Number of Euler integration steps.
        noise_distribution: "uniform" for rand, "normal" for randn. Any value
            other than "uniform" falls back to randn.

    Returns:
        Generated designs of shape (N, K, P), without gradient history.
    """
    device = target_labels.device
    n_targets = target_labels.shape[0]
    # Resolve the sampler once instead of re-testing the string per candidate.
    sample_noise = torch.rand if noise_distribution == "uniform" else torch.randn
    all_candidates = []

    # Pure inference: gradients are never needed for candidate generation.
    with torch.no_grad():
        for _ in range(k):
            noise = sample_noise(n_targets, num_design_params, device=device)

            # State at t=1: labels followed by noise.
            augmented_labels = torch.cat([target_labels, noise], dim=1)

            # Inverse (synthesis) pass.
            if conditioning is not None:
                result = euler_method_with_conditioning(
                    model=model,
                    input=augmented_labels,
                    conditioning=conditioning,
                    start_t=1,
                    end_t=0,
                    steps=steps,
                )
            else:
                result = euler_method(
                    model=model,
                    input=augmented_labels,
                    start_t=1,
                    end_t=0,
                    steps=steps,
                )

            # Everything after the first L columns is the design.
            all_candidates.append(result[:, num_labels:])

    # Stack to (N, K, P)
    return torch.stack(all_candidates, dim=1)


# =============================================================================
# Uncertainty Metrics
# =============================================================================


def compute_candidate_uncertainties_ensemble(
    candidates: torch.Tensor,
    models: List[torch.nn.Module],
    num_labels: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
    batch_size: int = 200,
) -> torch.Tensor:
    """Compute ensemble variance for each candidate.

    For each candidate design, run all models in forward direction (analysis
    pass, t=0 -> t=1) and compute the variance of the predicted labels across
    models, averaged over the label dimensions.

    With a single model the sample variance is undefined (``Tensor.var`` would
    yield NaN for every candidate), so the uncertainty is reported as zero in
    that case.

    Args:
        candidates: Candidate designs (N, K, P).
        models: List of ensemble models. Must not be empty.
        num_labels: Label space dimension L.
        conditioning: Optional conditioning (N, C).
        steps: Number of Euler steps.
        batch_size: Batch size for processing to manage memory.

    Returns:
        Ensemble variance uncertainties (N, K).
    """
    device = candidates.device
    n_targets, k, _ = candidates.shape
    uncertainties = []

    for ki in tqdm(range(k), desc="Computing ensemble variance"):
        cand_k = candidates[:, ki, :]  # (N, P)

        # Process in batches for memory efficiency
        all_variances = []
        for batch_start in range(0, n_targets, batch_size):
            batch_end = min(batch_start + batch_size, n_targets)
            batch_cand = cand_k[batch_start:batch_end]
            batch_cond = None
            if conditioning is not None:
                batch_cond = conditioning[batch_start:batch_end]

            # State at t=0: zeros in the label slots followed by the design.
            augmented_designs = torch.cat(
                [
                    torch.zeros(batch_cand.shape[0], num_labels, device=device),
                    batch_cand,
                ],
                dim=1,
            )

            # Get predictions from all models
            predictions_list = []
            for m in models:
                with torch.no_grad():
                    if batch_cond is not None:
                        pred = euler_method_with_conditioning(
                            model=m,
                            input=augmented_designs,
                            conditioning=batch_cond,
                            start_t=0,
                            end_t=1,
                            steps=steps,
                        )[:, :num_labels]
                    else:
                        pred = euler_method(
                            model=m,
                            input=augmented_designs,
                            start_t=0,
                            end_t=1,
                            steps=steps,
                        )[:, :num_labels]

                    # Replace non-finite predictions so that one diverged
                    # model does not turn the whole variance into NaN/inf.
                    pred = torch.nan_to_num(pred, nan=0.0, posinf=1e6, neginf=-1e6)
                    predictions_list.append(pred)

            # Stack: (batch, L, num_models)
            ensemble_predictions = torch.stack(predictions_list, dim=2)

            # Variance across models: (batch, L). The unbiased estimator
            # needs at least two models; a lone model has zero spread.
            if ensemble_predictions.shape[2] > 1:
                per_label_variance = ensemble_predictions.var(dim=2)
            else:
                per_label_variance = torch.zeros_like(ensemble_predictions[:, :, 0])

            # Mean over labels: (batch,)
            all_variances.append(per_label_variance.mean(dim=1))

        uncertainties.append(torch.cat(all_variances))

    return torch.stack(uncertainties, dim=1)  # (N, K)


def compute_candidate_uncertainties_fm(
    candidates: torch.Tensor,
    target_labels: torch.Tensor,
    model: torch.nn.Module,
    num_labels: int,
    num_design_params: int,
    conditioning: Optional[torch.Tensor] = None,
    batch_size: int = 100,
) -> torch.Tensor:
    """Score every candidate with the flow matching loss.

    For each of the K candidate slots, the target labels are paired with
    freshly drawn uniform noise and handed to ``compute_fm_losses`` together
    with the candidate designs of that slot.

    Args:
        candidates: Candidate designs (N, K, P).
        target_labels: Target labels (N, L).
        model: Reference model.
        num_labels: Label space dimension L.
        num_design_params: Design space dimension P.
        conditioning: Optional conditioning (N, C).
        batch_size: Batch size forwarded to ``compute_fm_losses``.

    Returns:
        FM loss uncertainties (N, K).
    """
    device = candidates.device
    n_targets, k, _ = candidates.shape

    per_slot_losses = []
    for slot in tqdm(range(k), desc="Computing FM loss"):
        # New noise per slot; it is not the noise the candidate was made from.
        fresh_noise = torch.rand(n_targets, num_design_params, device=device)
        labels_with_noise = torch.cat([target_labels, fresh_noise], dim=1)

        per_slot_losses.append(
            compute_fm_losses(
                simulated_designs=candidates[:, slot, :],
                augmented_labels=labels_with_noise,
                model=model,
                num_labels=num_labels,
                device=device,
                conditioning=conditioning,
                batch_size=batch_size,
            )
        )

    # One column per candidate slot.
    return torch.stack(per_slot_losses, dim=1)


def compute_candidate_uncertainties_zero_deviation(
    model: torch.nn.Module,
    target_labels: torch.Tensor,
    num_design_params: int,
    num_labels: int,
    k: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute zero-deviation uncertainty for K candidates per target.

    This exploits the Diag-CFM architecture: during synthesis (t=1 to t=0),
    the first L components should go from y to ~0. The deviation from zero
    (mean of the squared first L output components) indicates generation
    uncertainty.

    State convention:
    - t=1: [y; z] (labels + noise)
    - t=0: [~0; x] (near-zeros + design)

    The synthesis passes run under ``torch.no_grad()`` (matching the
    ensemble-variance computation in this module). Both returned tensors are
    therefore free of gradient history, which keeps the K integration graphs
    from being held in memory by the returned candidates.

    Args:
        model: Flow matching model.
        target_labels: Target labels (N, L).
        num_design_params: Design space dimension P.
        num_labels: Label space dimension L.
        k: Number of candidates per target. Must be at least 1, because
            ``torch.stack`` rejects an empty list.
        conditioning: Optional conditioning (N, C).
        steps: Number of Euler integration steps.

    Returns:
        Tuple of:
        - Zero-deviation uncertainties (N, K)
        - Generated candidates (N, K, P)
    """
    device = target_labels.device
    n_targets = target_labels.shape[0]
    uncertainties = []
    all_candidates = []

    # Pure inference: gradients are never needed for generation or scoring.
    with torch.no_grad():
        for _ in range(k):
            # Fresh uniform noise for this round of candidates.
            noise = torch.rand(n_targets, num_design_params, device=device)

            # Input state at t=1: [y; z]
            input_state = torch.cat([target_labels, noise], dim=1)

            # Synthesis: t=1 -> t=0
            if conditioning is not None:
                result = euler_method_with_conditioning(
                    model=model,
                    input=input_state,
                    conditioning=conditioning,
                    start_t=1,
                    end_t=0,
                    steps=steps,
                )
            else:
                result = euler_method(
                    model=model,
                    input=input_state,
                    start_t=1,
                    end_t=0,
                    steps=steps,
                )

            # Output at t=0: [~0; x]
            zero_part = result[:, :num_labels]  # Should be ~0
            design = result[:, num_labels:]  # Generated design

            # Zero-deviation: mean squared distance from zero, shape (N,)
            uncertainties.append((zero_part**2).mean(dim=1))
            all_candidates.append(design)

    return torch.stack(uncertainties, dim=1), torch.stack(all_candidates, dim=1)


def compute_candidate_uncertainties_self_consistency(
    candidates: torch.Tensor,
    target_labels: torch.Tensor,
    model: torch.nn.Module,
    num_labels: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
) -> torch.Tensor:
    """Compute self-consistency uncertainty for each candidate.

    This feeds the generated design with actual zeros (not the model's ~0 output)
    through the analysis direction and checks if the target labels reconstruct.

    State convention:
    - t=0: [0; x] (actual zeros + design)
    - t=1: [y_reconstructed; ?] (reconstructed labels + ?)

    The uncertainty is the mean squared difference between y_reconstructed
    and y* over the L label dimensions.

    The analysis passes run under ``torch.no_grad()`` (matching the
    ensemble-variance computation in this module), so no autograd graph is
    retained across the Euler steps.

    Args:
        candidates: Candidate designs (N, K, P).
        target_labels: Target labels (N, L).
        model: Flow matching model.
        num_labels: Label space dimension L.
        conditioning: Optional conditioning (N, C).
        steps: Number of Euler integration steps.

    Returns:
        Self-consistency uncertainties (N, K).
    """
    device = candidates.device
    n_targets, k, _ = candidates.shape
    uncertainties = []

    # The zero block is identical for every candidate slot; torch.cat copies
    # it into a new tensor each time, so sharing one instance is safe.
    zeros = torch.zeros(n_targets, num_labels, device=device)

    # Pure inference: gradients are never needed for this score.
    with torch.no_grad():
        for ki in range(k):
            cand_k = candidates[:, ki, :]  # (N, P)

            # Input at t=0: [0; x] (actual zeros, not the model's output)
            input_state = torch.cat([zeros, cand_k], dim=1)

            # Analysis: t=0 -> t=1
            if conditioning is not None:
                result = euler_method_with_conditioning(
                    model=model,
                    input=input_state,
                    conditioning=conditioning,
                    start_t=0,
                    end_t=1,
                    steps=steps,
                )
            else:
                result = euler_method(
                    model=model,
                    input=input_state,
                    start_t=0,
                    end_t=1,
                    steps=steps,
                )

            # Output at t=1: [y_reconstructed; ?]
            y_reconstructed = result[:, :num_labels]

            # How well the round trip reproduces the requested labels: (N,)
            uncertainties.append(((y_reconstructed - target_labels) ** 2).mean(dim=1))

    return torch.stack(uncertainties, dim=1)  # (N, K)


def compute_roundtrip_errors(
    candidates: torch.Tensor,
    target_labels: torch.Tensor,
    forward_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Measure how far each candidate's achieved labels are from the target.

    ``forward_fn`` is called once per candidate slot on an (N, P) slice, and
    the mean squared difference to ``target_labels`` is taken over the label
    dimension.

    Args:
        candidates: Candidate designs (N, K, P).
        target_labels: Target labels (N, L).
        forward_fn: Function that maps designs to labels (ground truth).

    Returns:
        Round-trip errors (N, K).
    """
    _, num_slots, _ = candidates.shape

    def slot_error(slot: int) -> torch.Tensor:
        # (N, P) designs of one slot -> (N, L) labels -> (N,) mean squared error
        achieved = forward_fn(candidates[:, slot, :])
        return (achieved - target_labels).pow(2).mean(dim=1)

    # One column per candidate slot: (N, K)
    return torch.stack([slot_error(slot) for slot in range(num_slots)], dim=1)


# =============================================================================
# Main Evaluation Function
# =============================================================================


def evaluate_select_best(
    models: List[torch.nn.Module],
    target_labels: torch.Tensor,
    forward_fn: Callable[[torch.Tensor], torch.Tensor],
    num_design_params: int,
    num_labels: int,
    k: Optional[int] = None,
    n_samples: Optional[int] = None,
    conditioning: Optional[torch.Tensor] = None,
) -> Dict:
    """Run select-best evaluation with all uncertainty metrics.

    Computes four uncertainty metrics:
    - Zero-Deviation: Diag-CFM specific, measures deviation from zero at t=0
    - Self-Consistency: Model round-trip consistency
    - Ensemble Variance: Variance of predictions across ensemble models
    - FM Loss: Flow matching loss at interpolation point

    Candidates are generated once (together with the zero-deviation scores)
    and reused for every other metric and for the ground-truth errors.

    Args:
        models: List of Diag-CFM models (ensemble). First model is used as
            reference. Must not be empty.
        target_labels: Target labels (N, L). Only the first ``n_samples`` rows
            are used.
        forward_fn: Ground-truth forward function for error computation.
        num_design_params: Design space dimension P.
        num_labels: Label space dimension L.
        k: Number of candidates per target. Defaults to
            ``_config["k_candidates"]``.
        n_samples: Number of target samples to evaluate. Defaults to
            ``_config["n_samples"]``.
        conditioning: Optional conditioning (N, C), truncated like the labels.

    Returns:
        Dictionary with:
        - n_targets: Number of targets evaluated
        - k: Number of candidates per target
        - num_ensemble_models: Number of models in the ensemble
        - random_error: Expected error of uniform random selection, i.e. the
          mean error over all candidates
        - oracle_error: Mean error with oracle (best) selection
        - oracle_improvement: % improvement of oracle over random
        - metrics: Dict with correlation, error and improvement for each UQ
          metric

        Improvements are reported as 0.0 when ``random_error`` is exactly
        zero, since there is nothing left to improve and the ratio is
        undefined.
    """
    # Use config defaults if not specified
    if k is None:
        k = _config["k_candidates"]
    if n_samples is None:
        n_samples = _config["n_samples"]

    device = target_labels.device
    target_labels = target_labels[:n_samples]
    if conditioning is not None:
        conditioning = conditioning[:n_samples]
    n_targets = target_labels.shape[0]

    # The first model acts as reference for all single-model metrics.
    reference_model = models[0]

    print(f"  Using {n_targets} samples, K={k} candidates, {len(models)} ensemble models")

    # Compute zero-deviation and candidates together (they share synthesis pass)
    print("  Computing zero-deviation uncertainties...")
    zero_dev_uncertainties, candidates = compute_candidate_uncertainties_zero_deviation(
        model=reference_model,
        target_labels=target_labels,
        num_design_params=num_design_params,
        num_labels=num_labels,
        k=k,
        conditioning=conditioning,
    )

    # Compute round-trip errors using ground-truth forward function
    print("  Computing round-trip errors...")
    rt_errors = compute_roundtrip_errors(candidates, target_labels, forward_fn)

    # Compute self-consistency uncertainties
    print("  Computing self-consistency uncertainties...")
    self_consistency_uncertainties = compute_candidate_uncertainties_self_consistency(
        candidates=candidates,
        target_labels=target_labels,
        model=reference_model,
        num_labels=num_labels,
        conditioning=conditioning,
    )

    # Compute ensemble variance (using all models)
    print("  Computing ensemble variance...")
    ensemble_uncertainties = compute_candidate_uncertainties_ensemble(
        candidates=candidates,
        models=models,
        num_labels=num_labels,
        conditioning=conditioning,
    )

    # Compute FM loss
    print("  Computing FM loss...")
    fm_uncertainties = compute_candidate_uncertainties_fm(
        candidates=candidates,
        target_labels=target_labels,
        model=reference_model,
        num_labels=num_labels,
        num_design_params=num_design_params,
        conditioning=conditioning,
    )

    # Flattened CPU copy of the errors for the correlation computation.
    errors_flat = rt_errors.detach().cpu().numpy().flatten()

    def to_numpy_flat(t):
        """Return ``t`` as a flat NumPy array (tensors are detached first)."""
        if isinstance(t, torch.Tensor):
            return t.detach().cpu().numpy().flatten()
        return t.flatten()

    def safe_corrcoef(a, b):
        """Compute correlation, returning 0 if either array has zero variance."""
        if np.std(a) == 0 or np.std(b) == 0:
            return 0.0
        corr = np.corrcoef(a, b)[0, 1]
        return 0.0 if np.isnan(corr) else float(corr)

    def select_error(uncertainties):
        """Mean true error when picking the lowest-uncertainty candidate."""
        if isinstance(uncertainties, np.ndarray):
            uncertainties = torch.from_numpy(uncertainties).reshape(n_targets, k)
        uncertainties = uncertainties.to(device)
        _, selected = uncertainties.min(dim=1)
        selected_errors = rt_errors[torch.arange(n_targets, device=device), selected]
        return selected_errors.mean().item()

    # Uniform random selection has the mean over all candidates as its
    # expected error; the oracle always takes the per-target minimum.
    random_error = rt_errors.mean().item()
    oracle_error = rt_errors.min(dim=1)[0].mean().item()

    def improvement_pct(error):
        """Percentage reduction of ``error`` relative to random selection."""
        if random_error == 0:
            # Undefined ratio; avoid ZeroDivisionError on a perfect baseline.
            return 0.0
        return (1 - error / random_error) * 100

    uncertainty_by_metric = {
        "Zero-Deviation": zero_dev_uncertainties,
        "Self-Consistency": self_consistency_uncertainties,
        "Ensemble Variance": ensemble_uncertainties,
        "FM Loss": fm_uncertainties,
    }

    metrics = {}
    for name, metric_uncertainties in uncertainty_by_metric.items():
        selected_error = select_error(metric_uncertainties)
        metrics[name] = {
            "correlation": safe_corrcoef(to_numpy_flat(metric_uncertainties), errors_flat),
            "error": selected_error,
            "improvement": improvement_pct(selected_error),
        }

    return {
        "n_targets": n_targets,
        "k": k,
        "num_ensemble_models": len(models),
        "random_error": random_error,
        "oracle_error": oracle_error,
        "oracle_improvement": improvement_pct(oracle_error),
        "metrics": metrics,
    }


# =============================================================================
# Dataset-Specific Evaluation Functions
# =============================================================================


def evaluate_gas_turbine(device: torch.device, verbose: bool = True) -> Dict:
    """Run the select-best experiment on the Gas Turbine test split.

    Loads the Diag-CFM ensemble, draws a seeded random subset of test labels,
    and scores candidates with the three surrogate models as ground truth.

    Args:
        device: Computation device.
        verbose: Whether to print progress.

    Returns:
        Results dictionary from ``evaluate_select_best`` extended with the
        dataset name and the design/label dimensions.
    """
    from uq_diagcfm.data_utils_gas_turbine import (
        GasTurbineDataset,
        LEN_LABELS,
        LEN_PARAMETERS,
        make_surrogates,
    )
    from uq_diagcfm.ensembles import load_gas_turbine_diag_cfm_ensemble

    if verbose:
        print("Evaluating Gas Turbine...")

    # Ensemble (sorted by final train loss, best first); only the models and
    # their checkpoint names are needed here.
    models, _run_infos, checkpoint_names, _criteria = load_gas_turbine_diag_cfm_ensemble(
        device=device, verbose=verbose
    )
    if verbose:
        print(f"  Using {len(models)} ensemble models, reference: {checkpoint_names[0]}")

    # Collect the label entry (index 1) of every test item into one tensor.
    dataset = GasTurbineDataset(split="test")
    label_rows = [dataset[i][1] for i in range(len(dataset))]
    all_labels = torch.stack(label_rows, dim=0)

    # Seeded subsample without replacement, capped at the available count.
    seed = _config["random_seed"]
    n_available = len(all_labels)
    sample_size = min(_config["n_samples"], n_available)
    indices = np.random.default_rng(seed).choice(n_available, size=sample_size, replace=False)
    test_labels = all_labels[indices].to(device)
    if verbose:
        print(f"  Randomly sampled {len(test_labels)} test samples (seed={seed})")

    # One surrogate per label group, all moved to the computation device.
    unmix_o_net, io_pd_net, ifd1_net = (net.to(device) for net in make_surrogates())

    def forward_fn(designs):
        # The surrogates are queried on designs clamped to [0, 1].
        bounded = torch.clamp(designs, 0, 1)
        outputs = [net(bounded) for net in (unmix_o_net, io_pd_net, ifd1_net)]
        return torch.cat(outputs, dim=1)

    results = evaluate_select_best(
        models=models,
        target_labels=test_labels,
        forward_fn=forward_fn,
        num_design_params=LEN_PARAMETERS,
        num_labels=LEN_LABELS,
    )
    results.update(
        {
            "dataset": "Gas Turbine",
            "num_design_params": LEN_PARAMETERS,
            "num_labels": LEN_LABELS,
        }
    )
    return results


def evaluate_unifoil(device: torch.device, verbose: bool = True) -> Dict:
    """Run the select-best experiment on the Unifoil validation split.

    Unifoil uses conditional generation with physical parameters (angle of
    attack, Mach number); the same conditioning rows are fed to the models
    and to the surrogate used as ground truth.

    Args:
        device: Computation device.
        verbose: Whether to print progress.

    Returns:
        Results dictionary from ``evaluate_select_best`` extended with the
        dataset name and the design/label dimensions.
    """
    from uq_diagcfm.data_utils_unifoil import (
        LEN_DESIGN_PARAMETERS,
        LEN_PHYSICAL_PERFORMANCE,
        UnifoilDataset,
        make_unifoil_surrogate,
    )
    from uq_diagcfm.ensembles import load_unifoil_diag_cfm_ensemble

    if verbose:
        print("Evaluating Unifoil...")

    # Performance labels form the label space.
    num_labels = LEN_PHYSICAL_PERFORMANCE

    models, _run_infos, checkpoint_names, _criteria = load_unifoil_diag_cfm_ensemble(
        device=device, verbose=verbose
    )
    if verbose:
        print(f"  Using {len(models)} ensemble models, reference: {checkpoint_names[0]}")

    # Validation data (Unifoil has no test split). Item entry 2 holds the
    # labels and entry 1 the conditioning, both as NumPy arrays.
    dataset = UnifoilDataset(split="val")
    item_ids = range(len(dataset))
    all_labels = torch.stack([torch.from_numpy(dataset[i][2]) for i in item_ids], dim=0)
    all_conditioning = torch.stack(
        [torch.from_numpy(dataset[i][1]) for i in item_ids], dim=0
    )

    # Seeded subsample without replacement; labels and conditioning share it.
    seed = _config["random_seed"]
    n_available = len(all_labels)
    sample_size = min(_config["n_samples"], n_available)
    indices = np.random.default_rng(seed).choice(n_available, size=sample_size, replace=False)
    test_labels = all_labels[indices].to(device)
    conditioning = all_conditioning[indices].to(device)
    if verbose:
        print(f"  Randomly sampled {len(test_labels)} test samples (seed={seed})")

    surrogate = make_unifoil_surrogate().to(device)
    surrogate.eval()

    def forward_fn(designs):
        # Surrogate input is (design_params, physical_params); the leading
        # conditioning rows are paired with the designs row by row.
        matching_cond = conditioning[: designs.shape[0]]
        with torch.no_grad():
            return surrogate(torch.cat([designs, matching_cond], dim=1))

    results = evaluate_select_best(
        models=models,
        target_labels=test_labels,
        forward_fn=forward_fn,
        num_design_params=LEN_DESIGN_PARAMETERS,
        num_labels=num_labels,
        conditioning=conditioning,
    )
    results.update(
        {
            "dataset": "Unifoil",
            "num_design_params": LEN_DESIGN_PARAMETERS,
            "num_labels": num_labels,
        }
    )
    return results


def evaluate_dtlz(
    device: torch.device, num_design_params: int, verbose: bool = True
) -> Optional[Dict]:
    """Run the select-best experiment on the DTLZ2 benchmark.

    Uses three objectives and the analytical DTLZ2 function, rescaled by the
    dataset's label scale, as ground truth.

    Args:
        device: Computation device.
        num_design_params: Design space dimension P.
        verbose: Whether to print progress.

    Returns:
        Results dictionary from ``evaluate_select_best`` extended with the
        dataset name and dimensions, or None if no models were found.
    """
    from uq_diagcfm.data_utils_dtlz import DTLZDataset, make_dtlz_surrogate
    from uq_diagcfm.ensembles import load_dtlz_diag_cfm_ensemble

    if verbose:
        print(f"Evaluating DTLZ P={num_design_params}...")

    num_objectives = 3

    models, _run_infos, checkpoint_names, _criteria = load_dtlz_diag_cfm_ensemble(
        P=num_design_params, device=device, verbose=verbose
    )
    # Nothing to evaluate when no checkpoint exists for this dimension.
    if len(models) == 0:
        if verbose:
            print(f"  No models found for DTLZ P={num_design_params}")
        return None
    if verbose:
        print(f"  Using {len(models)} ensemble models, reference: {checkpoint_names[0]}")

    dataset = DTLZDataset(
        split="test",
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name="dtlz2",
        normalize_labels=True,
        sampling_strategy="stratified",
    )
    all_labels = dataset.labels
    label_scale = dataset.label_scale

    # Seeded subsample without replacement, capped at the available count.
    seed = _config["random_seed"]
    n_available = len(all_labels)
    sample_size = min(_config["n_samples"], n_available)
    indices = np.random.default_rng(seed).choice(n_available, size=sample_size, replace=False)
    test_labels = all_labels[indices].to(device)
    if verbose:
        print(f"  Randomly sampled {len(test_labels)} test samples (seed={seed})")

    analytic_dtlz = make_dtlz_surrogate("dtlz2", num_objectives)

    def forward_fn(designs):
        # Evaluate on designs clamped to [0, 1], then divide by the dataset's
        # label scale so the result is comparable with the sampled labels.
        bounded = torch.clamp(designs, 0, 1)
        return analytic_dtlz(bounded) / label_scale

    results = evaluate_select_best(
        models=models,
        target_labels=test_labels,
        forward_fn=forward_fn,
        num_design_params=num_design_params,
        num_labels=num_objectives,
    )
    results.update(
        {
            "dataset": f"DTLZ P={num_design_params}",
            "num_design_params": num_design_params,
            "num_labels": num_objectives,
        }
    )
    return results


# =============================================================================
# Plotting Functions
# =============================================================================


def create_summary_plot(all_results: List[Dict], output_file: Path) -> None:
    """Draw a two-panel grouped bar chart over all datasets and save it.

    The left panel shows the uncertainty-error correlation per metric, the
    right panel the improvement over random selection per metric, with the
    oracle improvement marked by a star.

    Args:
        all_results: List of result dictionaries from evaluate_* functions.
        output_file: Path to save the plot.
    """
    fig, (corr_ax, gain_ax) = plt.subplots(1, 2, figsize=(16, 5))

    dataset_names = [res["dataset"] for res in all_results]
    # Insertion order of this mapping fixes the bar order within a group.
    palette = {
        "Zero-Deviation": "#9b59b6",  # Purple
        "Self-Consistency": "#e67e22",  # Orange
        "Ensemble Variance": "#e74c3c",  # Red
        "FM Loss": "#2ecc71",  # Green
    }
    metric_names = list(palette)

    group_positions = np.arange(len(dataset_names))
    bar_width = 0.15  # Narrow enough to fit the four metric bars per group
    group_centers = group_positions + bar_width * (len(metric_names) - 1) / 2

    def draw_metric_bars(ax, field):
        # One bar series per metric, shifted by its position within the group.
        for offset, name in enumerate(metric_names):
            heights = [res["metrics"][name][field] for res in all_results]
            ax.bar(
                group_positions + offset * bar_width,
                heights,
                bar_width,
                label=name,
                color=palette[name],
                alpha=0.8,
            )

    def decorate(ax, ylabel, title):
        # Labels, ticks, legend, zero line and grid shared by both panels.
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_xlabel("Dataset", fontsize=12)
        ax.set_title(title, fontsize=13, fontweight="bold")
        ax.set_xticks(group_centers)
        ax.set_xticklabels(dataset_names, rotation=15, ha="right")
        ax.legend(loc="upper right", fontsize=9)
        ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5)
        ax.grid(True, alpha=0.3, axis="y")

    # Panel 1: correlations
    draw_metric_bars(corr_ax, "correlation")
    decorate(corr_ax, "Correlation with Error", "Uncertainty-Error Correlation")

    # Panel 2: improvements, with the oracle as reference marker
    draw_metric_bars(gain_ax, "improvement")
    gain_ax.scatter(
        group_centers,
        [res["oracle_improvement"] for res in all_results],
        color="gold",
        marker="*",
        s=150,
        zorder=5,
        label="Oracle",
    )
    decorate(
        gain_ax,
        "Improvement over Random (%)",
        f"Select-Best Performance (N={_config['n_samples']}, K={_config['k_candidates']})",
    )

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Summary plot saved to: {output_file}")
    plt.close()


def create_individual_bar_plot(result: Dict, output_file: Path) -> None:
    """Create individual bar plot for a single dataset.

    Draws one bar per selection strategy (random, the four UQ metrics, oracle)
    whose height is that strategy's error. Every bar except the random
    baseline is annotated with its relative improvement over random,
    ``(1 - error / random_error) * 100``. The annotations are skipped when the
    random baseline error is zero, because the ratio is undefined there.

    Args:
        result: Result dictionary from evaluate_* function. The keys read are
            ``random_error``, ``oracle_error``, ``dataset``, ``n_targets``,
            ``k`` and ``metrics[<name>]["error"]`` for each UQ metric.
        output_file: Path to save the plot.
    """
    # (bar label, key into result["metrics"], bar color) for each UQ metric,
    # kept together so labels, errors and colors cannot drift out of step.
    uq_bars = [
        ("Zero-\nDeviation", "Zero-Deviation", "#9b59b6"),
        ("Self-\nConsistency", "Self-Consistency", "#e67e22"),
        ("Ensemble\nVariance", "Ensemble Variance", "#e74c3c"),
        ("FM Loss", "FM Loss", "#2ecc71"),
    ]

    # Read every required key before creating the figure so that a malformed
    # result dictionary does not leave an orphaned open figure behind.
    random_err = result["random_error"]
    methods = ["Random"] + [label for label, _, _ in uq_bars] + ["Oracle"]
    errors = (
        [random_err]
        + [result["metrics"][key]["error"] for _, key, _ in uq_bars]
        + [result["oracle_error"]]
    )
    colors = ["gray"] + [color for _, _, color in uq_bars] + ["gold"]
    title = f'Select-Best: {result["dataset"]} (N={result["n_targets"]}, K={result["k"]})'

    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        bars = ax.bar(methods, errors, color=colors, alpha=0.8, edgecolor="black")

        # Add improvement labels (undefined for a zero baseline, so skipped then)
        if random_err != 0:
            for i, (bar, err) in enumerate(zip(bars, errors)):
                if i == 0:  # Skip random: it is the reference itself
                    continue
                improvement = (1 - err / random_err) * 100
                color = "darkgreen" if improvement > 0 else "darkred"
                ax.annotate(
                    f"{improvement:+.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=9,
                    color=color,
                    fontweight="bold",
                )

        ax.set_ylabel("Mean Round-Trip Error", fontsize=12)
        ax.set_title(title, fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3, axis="y")

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches="tight")
        print(f"Plot saved to: {output_file}")
    finally:
        # Close this specific figure even if drawing or saving failed, so
        # repeated calls (one per dataset) do not accumulate open figures.
        plt.close(fig)


def create_individual_bar_plots(all_results: List[Dict], output_dir: Path) -> None:
    """Render one select-best bar chart per dataset result.

    Each figure is written into ``output_dir`` as ``select_best_<dataset>.png``,
    where ``<dataset>`` is the result's ``"dataset"`` value with spaces turned
    into underscores and ``=`` characters dropped.

    Args:
        all_results: List of result dictionaries.
        output_dir: Directory to save plots.
    """
    # Single-pass character mapping: " " -> "_", "=" -> removed.
    filename_table = str.maketrans({" ": "_", "=": None})

    for entry in all_results:
        safe_name = entry["dataset"].translate(filename_table)
        target = output_dir / f"select_best_{safe_name}.png"
        create_individual_bar_plot(entry, target)


# =============================================================================
# Results Printing and Saving
# =============================================================================


def print_results_table(all_results: List[Dict]) -> None:
    """Print the select-best summary as a fixed-width text table.

    For every result there is one row per uncertainty metric (correlation and
    improvement) followed by an oracle row and a blank line. The dataset name
    and its P / L values appear only on the first metric row of each group;
    P and L fall back to ``"?"`` when the result does not carry them.

    Args:
        all_results: List of result dictionaries.
    """
    heavy_rule = "=" * 100
    print("\n" + heavy_rule)
    print(f"SUMMARY: Select-Best (N={_config['n_samples']} samples, K={_config['k_candidates']} candidates)")
    print(heavy_rule)
    print(
        f"{'Dataset':<20} {'P':>5} {'L':>5} {'Metric':<20} {'Correlation':>12} {'Improvement':>12}"
    )
    print("-" * 100)

    # Fixed row order within each dataset group.
    ordered_metrics = (
        "Zero-Deviation",
        "Self-Consistency",
        "Ensemble Variance",
        "FM Loss",
    )

    for entry in all_results:
        design_dim = entry.get("num_design_params", "?")
        label_dim = entry.get("num_labels", "?")

        for row_idx, metric_name in enumerate(ordered_metrics):
            stats = entry["metrics"][metric_name]
            # Only the first row of a group shows the dataset identification.
            if row_idx == 0:
                name_col = entry["dataset"]
                p_col = str(design_dim)
                l_col = str(label_dim)
            else:
                name_col = p_col = l_col = ""
            print(
                f"{name_col:<20} {p_col:>5} {l_col:>5} {metric_name:<20} "
                f"{stats['correlation']:>12.4f} {stats['improvement']:>+11.1f}%"
            )

        # Oracle has no uncertainty score, hence no correlation value.
        print(
            f"{'':20} {'':>5} {'':>5} {'Oracle':<20} {'---':>12} "
            f"{entry['oracle_improvement']:>+11.1f}%"
        )
        print()


def save_results(all_results: List[Dict], output_file: Path) -> None:
    """Write the collected results to disk as indented JSON.

    Missing parent directories of ``output_file`` are created first, and a
    confirmation line is printed once the file has been written.

    Args:
        all_results: List of result dictionaries.
        output_file: Path to save JSON.
    """
    target_dir = output_file.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with output_file.open("w") as handle:
        json.dump(all_results, handle, indent=2)

    print(f"Results saved to: {output_file}")


# =============================================================================
# Memory Management
# =============================================================================


def clear_memory():
    """Clear GPU/MPS memory and run garbage collection.

    Runs the Python garbage collector first, then asks each accelerator
    backend that is actually available to empty its cache.
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ``torch.mps`` is importable on every platform, so the mere presence of
    # ``torch.mps.empty_cache`` says nothing about whether an MPS device
    # exists. Check backend availability first so the call is only made where
    # MPS is usable. ``getattr`` keeps this working on builds that do not
    # expose ``torch.backends.mps``.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    if mps_ready and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()


# =============================================================================
# CLI Interface
# =============================================================================


def main():
    """Main CLI entry point for select-best evaluation.

    Parses the command-line options, stores the sampling settings in the
    module-level ``_config``, runs the requested dataset evaluations, prints
    the summary table, writes the JSON results and, unless ``--no-plots`` is
    given or there are no results, generates the summary and per-dataset
    figures.

    An invalid ``--P`` value is reported through ``parser.error`` (usage
    message, exit status 2) before any output directory is created.
    """
    from uq_diagcfm.paths import PAPER_FIGURES_DIR, RESULTS_UQ_DIR, ensure_paper_dirs_exist

    parser = argparse.ArgumentParser(
        description="Run select-best evaluation with Diag-CFM uncertainty metrics."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="all",
        choices=["all", "gas_turbine", "unifoil", "dtlz"],
        help="Dataset to evaluate (default: all)",
    )
    parser.add_argument(
        "--P",
        type=str,
        default="all",
        help="For DTLZ: design dimension P or 'all' for [12, 24, 50, 100] (default: all)",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=DEFAULT_N_SAMPLES,
        help=f"Number of target samples (default: {DEFAULT_N_SAMPLES})",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=DEFAULT_K_CANDIDATES,
        help=f"Number of candidates per target (default: {DEFAULT_K_CANDIDATES})",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_RANDOM_SEED,
        help=f"Random seed for reproducibility (default: {DEFAULT_RANDOM_SEED})",
    )
    parser.add_argument(
        "--no-plots",
        action="store_true",
        help="Skip generating plots",
    )

    args = parser.parse_args()

    # Determine DTLZ P values. Validated right after parsing so a bad value
    # produces a normal argparse usage error instead of a ValueError traceback
    # raised after directories have already been created.
    if args.P == "all":
        dtlz_p_values = [12, 24, 50, 100]
    else:
        try:
            dtlz_p_values = [int(args.P)]
        except ValueError:
            # parser.error() prints usage and exits; it does not return.
            parser.error(f"--P must be an integer or 'all', got {args.P!r}")

    # Update config from CLI arguments
    _config["n_samples"] = args.n_samples
    _config["k_candidates"] = args.k
    _config["random_seed"] = args.seed

    ensure_paper_dirs_exist()
    RESULTS_UQ_DIR.mkdir(exist_ok=True, parents=True)

    device = get_device()
    all_results = []

    # Determine which datasets to evaluate
    if args.dataset == "all":
        datasets = ["gas_turbine", "unifoil", "dtlz"]
    else:
        datasets = [args.dataset]

    # Run evaluations, releasing accelerator memory after each one
    for dataset in datasets:
        if dataset == "gas_turbine":
            all_results.append(evaluate_gas_turbine(device))
            clear_memory()

        elif dataset == "unifoil":
            all_results.append(evaluate_unifoil(device))
            clear_memory()

        elif dataset == "dtlz":
            for P in dtlz_p_values:
                result = evaluate_dtlz(device, P)
                # A falsy result is left out of the summary
                if result:
                    all_results.append(result)
                clear_memory()

    # Print and save results
    print_results_table(all_results)
    save_results(all_results, RESULTS_UQ_DIR / "select_best_all_datasets.json")

    # Generate plots (nothing to draw when no evaluation produced a result)
    if not args.no_plots and all_results:
        create_summary_plot(all_results, PAPER_FIGURES_DIR / "select_best_summary.png")
        create_individual_bar_plots(all_results, PAPER_FIGURES_DIR)


# Run the CLI only when this file is executed as a script (including via
# ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()
