"""
Error-Rejection Curves (Abstention Analysis) for Uncertainty Quantification.

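This module implements error-rejection curves, which show how rejecting
(abstaining on) the highest-uncertainty samples reduces the mean error of the
samples that remain. A monotonically decreasing curve indicates the model
"knows when it doesn't know." It also provides a binned calibration analysis
(mean error per uncertainty bin) for the same uncertainty metrics.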

The experiment:
1. Generate one design per target label
2. Compute uncertainty scores using all 4 UQ metrics
3. Rank samples by uncertainty
4. Progressively reject the most uncertain samples
5. Plot mean error vs rejection rate

Usage:
    python -m uq_diagcfm.error_rejection --dataset gas_turbine
    python -m uq_diagcfm.error_rejection --dataset dtlz --P 50
"""

import argparse
import gc
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from uq_diagcfm.solvers import euler_method, euler_method_with_conditioning
from uq_diagcfm.uq_evaluation_utils import compute_fm_losses
from uq_diagcfm.utils import get_device

# =============================================================================
# Default Constants
# =============================================================================

DEFAULT_N_SAMPLES = 1000
DEFAULT_RANDOM_SEED = 42

_config = {
    "n_samples": DEFAULT_N_SAMPLES,
    "random_seed": DEFAULT_RANDOM_SEED,
}


# =============================================================================
# Core Error-Rejection Computation
# =============================================================================


def compute_error_rejection_curve(
    uncertainty: np.ndarray,
    error: np.ndarray,
    rejection_rates: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the error-rejection curve for a single uncertainty metric.

    For each rejection rate ``r`` the ``floor(r * N)`` samples with the highest
    uncertainty are discarded and the mean error of the remaining samples is
    recorded.

    Args:
        uncertainty: Uncertainty scores for each sample (N,).
        error: Error values for each sample (N,).
        rejection_rates: Rejection rates to evaluate, any array-like of values
            in [0, 1] (default: 0%, 5%, ..., 50%). Values outside [0, 1] are
            clipped to "reject nothing" / "reject everything".

    Returns:
        Tuple of (rejection_rates, mean_errors) as NumPy arrays. ``mean_errors``
        is NaN for any rate that rejects every sample.
    """
    if rejection_rates is None:
        rejection_rates = np.arange(0, 0.55, 0.05)
    rejection_rates = np.asarray(rejection_rates, dtype=float)

    error = np.asarray(error)
    n_samples = len(uncertainty)

    # Indices ordered from most to least uncertain.
    sorted_indices = np.argsort(uncertainty)[::-1]

    mean_errors = np.full(len(rejection_rates), np.nan)
    for i, rate in enumerate(rejection_rates):
        # The epsilon absorbs float error in products such as
        # 0.29 * 100 == 28.999999999999996, which int() would truncate to 28.
        n_reject = int(np.floor(rate * n_samples + 1e-9))
        # Clip so negative rates don't select the wrong tail of the ordering.
        n_reject = min(max(n_reject, 0), n_samples)

        if n_reject < n_samples:
            # Keep the samples with the lowest uncertainty.
            mean_errors[i] = np.mean(error[sorted_indices[n_reject:]])

    return rejection_rates, mean_errors


# =============================================================================
# Calibration Analysis
# =============================================================================


def compute_calibration_curve(
    uncertainty: np.ndarray,
    error: np.ndarray,
    n_bins: int = 10,
) -> Dict:
    """Compute a binned calibration curve for a single uncertainty metric.

    Bins samples by uncertainty percentile and computes the mean error per bin.
    A well-calibrated metric should show monotonically increasing error with
    increasing uncertainty.

    Args:
        uncertainty: Uncertainty scores for each sample (N,).
        error: Error values for each sample (N,).
        n_bins: Number of bins requested (default: 10 for deciles).

    Returns:
        Dictionary with per-bin lists ``bin_centers``, ``mean_errors``,
        ``std_errors`` (standard error of the mean), ``bin_counts`` and
        ``mean_uncertainties`` (empty bins hold ``None``); ``n_bins``, the
        number of bins actually used (may be fewer than requested when
        percentile edges coincide); and summary metrics
        ``spearman_correlation``, ``spearman_pvalue``, ``pearson_correlation``,
        ``monotonicity`` and ``ece``. Undefined summary metrics fall back to
        0.0 (correlations, monotonicity) or 1.0 (p-value, ECE).
    """
    from scipy import stats

    # Percentile-based bin edges so bins hold roughly equal sample counts.
    percentiles = np.linspace(0, 100, n_bins + 1)
    bin_edges = np.percentile(uncertainty, percentiles)

    # Tied uncertainties produce duplicate percentile edges; drop them.
    bin_edges = np.unique(bin_edges)
    if len(bin_edges) < 3:
        # Too many ties, fall back to equal-width bins
        bin_edges = np.linspace(uncertainty.min(), uncertainty.max(), n_bins + 1)

    actual_n_bins = len(bin_edges) - 1

    # Only interior edges are passed, so indices run 0..actual_n_bins-1.
    bin_indices = np.digitize(uncertainty, bin_edges[1:-1])

    bin_centers = []
    mean_errors = []
    std_errors = []
    bin_counts = []
    mean_uncertainties = []

    for i in range(actual_n_bins):
        mask = bin_indices == i
        count = int(np.sum(mask))
        bin_counts.append(count)
        bin_centers.append((bin_edges[i] + bin_edges[i + 1]) / 2)

        if count > 0:
            bin_errors = error[mask]
            mean_errors.append(np.mean(bin_errors))
            # Standard error of the mean; reported as 0 for single-sample bins.
            std_errors.append(np.std(bin_errors) / np.sqrt(count) if count > 1 else 0)
            mean_uncertainties.append(np.mean(uncertainty[mask]))
        else:
            mean_errors.append(np.nan)
            std_errors.append(np.nan)
            mean_uncertainties.append(np.nan)

    mean_errors = np.array(mean_errors)
    std_errors = np.array(std_errors)
    bin_counts = np.array(bin_counts)
    mean_uncertainties = np.array(mean_uncertainties)

    # Calibration metrics are computed over non-empty bins only.
    valid_mask = ~np.isnan(mean_errors) & (bin_counts > 0)
    valid_errors = mean_errors[valid_mask]
    valid_uncertainties = mean_uncertainties[valid_mask]

    # Spearman correlation (should be positive and high if well-calibrated)
    if len(valid_errors) > 2:
        spearman_corr, spearman_p = stats.spearmanr(valid_uncertainties, valid_errors)
    else:
        spearman_corr, spearman_p = np.nan, np.nan

    # Pearson correlation at sample level; undefined (NaN) when either input
    # is constant, so report 0.0 in that case like the other fallbacks.
    if np.std(uncertainty) > 0 and np.std(error) > 0:
        pearson_corr = np.corrcoef(uncertainty, error)[0, 1]
    else:
        pearson_corr = 0.0

    # Monotonicity score: fraction of consecutive bins with increasing error
    if len(valid_errors) > 1:
        monotonicity = np.mean(np.diff(valid_errors) > 0)
    else:
        monotonicity = np.nan

    # Calibration gap: min-max normalised per-bin mean uncertainty vs mean
    # error, weighted by bin count. For perfect (rank-linear) calibration the
    # two normalised curves coincide. Not the classification-style ECE.
    if len(valid_errors) > 1:
        norm_unc = (valid_uncertainties - valid_uncertainties.min()) / (
            valid_uncertainties.max() - valid_uncertainties.min() + 1e-10
        )
        norm_err = (valid_errors - valid_errors.min()) / (
            valid_errors.max() - valid_errors.min() + 1e-10
        )
        weights = bin_counts[valid_mask] / np.sum(bin_counts[valid_mask])
        ece = np.sum(weights * np.abs(norm_err - norm_unc))
    else:
        ece = np.nan

    # NaN -> None in per-bin lists keeps the result JSON-friendly.
    return {
        "bin_centers": [float(x) for x in bin_centers],
        "mean_errors": [float(x) if not np.isnan(x) else None for x in mean_errors],
        "std_errors": [float(x) if not np.isnan(x) else None for x in std_errors],
        "bin_counts": [int(x) for x in bin_counts],
        "mean_uncertainties": [float(x) if not np.isnan(x) else None for x in mean_uncertainties],
        "n_bins": actual_n_bins,
        "spearman_correlation": float(spearman_corr) if not np.isnan(spearman_corr) else 0.0,
        "spearman_pvalue": float(spearman_p) if not np.isnan(spearman_p) else 1.0,
        "pearson_correlation": float(pearson_corr),
        "monotonicity": float(monotonicity) if not np.isnan(monotonicity) else 0.0,
        "ece": float(ece) if not np.isnan(ece) else 1.0,
    }


# =============================================================================
# Design Generation and Uncertainty Computation
# =============================================================================


def generate_designs_with_uncertainties(
    models: List[torch.nn.Module],
    target_labels: torch.Tensor,
    num_design_params: int,
    num_labels: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
    batch_size: int = 200,
) -> Dict[str, torch.Tensor]:
    """Generate one design per target label and score it with every UQ metric.

    The first model in ``models`` is the reference model used for the
    synthesis pass (``[y; z]`` integrated from t=1 to t=0), the
    self-consistency pass (``[0; x]`` integrated from t=0 to t=1) and the FM
    loss; the whole list is used for the ensemble variance.

    Args:
        models: List of ensemble models. First is reference.
        target_labels: Target labels (N, L).
        num_design_params: Design dimension P.
        num_labels: Label dimension L.
        conditioning: Optional conditioning (N, C).
        steps: Euler integration steps.
        batch_size: Batch size for ensemble/FM loss computation.

    Returns:
        Dictionary with:
        - designs: Generated designs (N, P)
        - zero_deviation: Mean squared label-slot residual after synthesis (N,)
        - self_consistency: MSE between re-analysed and target labels (N,)
        - ensemble_variance: Ensemble variance uncertainties (N,)
        - fm_loss: FM loss uncertainties (N,)
    """
    dev = target_labels.device
    n_rows = target_labels.shape[0]
    ref_model = models[0]

    def _integrate(state: torch.Tensor, t_from: int, t_to: int) -> torch.Tensor:
        """Run the reference model's Euler solver from ``t_from`` to ``t_to``."""
        with torch.no_grad():
            if conditioning is None:
                return euler_method(
                    model=ref_model,
                    input=state,
                    start_t=t_from,
                    end_t=t_to,
                    steps=steps,
                )
            return euler_method_with_conditioning(
                model=ref_model,
                input=state,
                conditioning=conditioning,
                start_t=t_from,
                end_t=t_to,
                steps=steps,
            )

    # Uniform noise fills the design slots of the synthesis input.
    noise = torch.rand(n_rows, num_design_params, device=dev)

    # Synthesis pass: [y; z] -> [~0; x]
    print("  Generating designs (synthesis pass)...")
    synthesized = _integrate(torch.cat([target_labels, noise], dim=1), 1, 0)
    designs = synthesized[:, num_labels:]
    zero_deviation = synthesized[:, :num_labels].pow(2).mean(dim=1)

    # Self-consistency: [0; x] -> [y_rec; ?]
    print("  Computing self-consistency...")
    analysis_state = torch.cat(
        [torch.zeros(n_rows, num_labels, device=dev), designs], dim=1
    )
    y_rec = _integrate(analysis_state, 0, 1)[:, :num_labels]
    self_consistency = (y_rec - target_labels).pow(2).mean(dim=1)

    print("  Computing ensemble variance...")
    ens_var = _compute_ensemble_variance(
        designs=designs,
        models=models,
        num_labels=num_labels,
        conditioning=conditioning,
        steps=steps,
        batch_size=batch_size,
    )

    print("  Computing FM loss...")
    fm_scores = _compute_fm_loss(
        designs=designs,
        target_labels=target_labels,
        noise=noise,
        model=ref_model,
        num_labels=num_labels,
        conditioning=conditioning,
        batch_size=batch_size,
    )

    return dict(
        designs=designs,
        zero_deviation=zero_deviation,
        self_consistency=self_consistency,
        ensemble_variance=ens_var,
        fm_loss=fm_scores,
    )


def _compute_ensemble_variance(
    designs: torch.Tensor,
    models: List[torch.nn.Module],
    num_labels: int,
    conditioning: Optional[torch.Tensor],
    steps: int,
    batch_size: int,
) -> torch.Tensor:
    """Compute the ensemble variance of the analysis-pass prediction per design.

    Each ensemble member integrates ``[0; x]`` from t=0 to t=1. Its first
    ``num_labels`` outputs are taken as its label prediction. Non-finite values
    are replaced: NaN -> 0, +inf -> 1e6, -inf -> -1e6. The variance across
    members (``torch.var`` default, i.e. unbiased) is then averaged over the
    label dimensions.

    Args:
        designs: Designs to analyse (N, P).
        models: Ensemble members.
        num_labels: Label dimension L.
        conditioning: Optional conditioning (N, C), sliced per batch.
        steps: Euler integration steps.
        batch_size: Number of designs processed per batch.

    Returns:
        Tensor of shape (N,). With a single ensemble member the unbiased
        variance is undefined (``torch.var`` would return NaN for every
        design), so zeros are returned instead.
    """
    device = designs.device
    n_samples = designs.shape[0]
    if n_samples == 0:
        # torch.cat([]) would raise on an empty batch list.
        return torch.zeros(0, device=device, dtype=designs.dtype)

    all_variances = []
    with torch.no_grad():
        for batch_start in range(0, n_samples, batch_size):
            batch_end = min(batch_start + batch_size, n_samples)
            batch_designs = designs[batch_start:batch_end]
            batch_cond = None if conditioning is None else conditioning[batch_start:batch_end]

            # Analysis input: [0; x]
            zeros = torch.zeros(batch_designs.shape[0], num_labels, device=device)
            analysis_input = torch.cat([zeros, batch_designs], dim=1)

            predictions = []
            for m in models:
                if batch_cond is not None:
                    out = euler_method_with_conditioning(
                        model=m,
                        input=analysis_input,
                        conditioning=batch_cond,
                        start_t=0,
                        end_t=1,
                        steps=steps,
                    )
                else:
                    out = euler_method(
                        model=m,
                        input=analysis_input,
                        start_t=0,
                        end_t=1,
                        steps=steps,
                    )
                pred = out[:, :num_labels]
                predictions.append(
                    torch.nan_to_num(pred, nan=0.0, posinf=1e6, neginf=-1e6)
                )

            # Stack: (batch, L, M) -> variance over models, mean over labels
            ensemble_preds = torch.stack(predictions, dim=2)
            if ensemble_preds.shape[2] > 1:
                batch_variance = ensemble_preds.var(dim=2).mean(dim=1)
            else:
                batch_variance = torch.zeros_like(ensemble_preds[:, 0, 0])
            all_variances.append(batch_variance)

    return torch.cat(all_variances)


def _compute_fm_loss(
    designs: torch.Tensor,
    target_labels: torch.Tensor,
    noise: torch.Tensor,
    model: torch.nn.Module,
    num_labels: int,
    conditioning: Optional[torch.Tensor],
    batch_size: int,
) -> torch.Tensor:
    """Score each design with the flow-matching loss of ``model``.

    The target labels and the synthesis noise are concatenated into the
    augmented label vector ``[y; z]``. It is passed to ``compute_fm_losses``
    together with the generated designs, on the designs' device.
    """
    return compute_fm_losses(
        simulated_designs=designs,
        augmented_labels=torch.cat([target_labels, noise], dim=1),
        model=model,
        num_labels=num_labels,
        device=designs.device,
        conditioning=conditioning,
        batch_size=batch_size,
    )


# =============================================================================
# Main Evaluation Function
# =============================================================================


def evaluate_error_rejection(
    models: List[torch.nn.Module],
    target_labels: torch.Tensor,
    forward_fn: Callable[[torch.Tensor], torch.Tensor],
    num_design_params: int,
    num_labels: int,
    n_samples: Optional[int] = None,
    conditioning: Optional[torch.Tensor] = None,
    rejection_rates: Optional[np.ndarray] = None,
) -> Dict:
    """Run error-rejection evaluation with all uncertainty metrics.

    Args:
        models: List of ensemble models.
        target_labels: Target labels (N, L).
        forward_fn: Ground-truth forward function.
        num_design_params: Design dimension P.
        num_labels: Label dimension L.
        n_samples: Maximum number of samples to evaluate
            (default: ``_config["n_samples"]``).
        conditioning: Optional conditioning (N, C).
        rejection_rates: Rejection rates to evaluate, any array-like
            (default: 0%, 5%, ..., 50%).

    Returns:
        Dictionary with:
        - ``n_samples``: number of samples actually evaluated.
        - ``num_ensemble_models``.
        - ``baseline_error``: mean round-trip error over all samples.
        - ``curves``: one entry per uncertainty metric, plus ``"Random"``
          (average of 100 seeded random rankings) and ``"Oracle"`` (ranking by
          true error).

        Each curve holds ``rejection_rates``, ``mean_errors``, ``correlation``
        and ``error_reduction_at_20pct``, the percent error reduction vs.
        ``baseline_error`` at the rate closest to 20% (0.0 if the baseline is
        zero).
    """
    if n_samples is None:
        n_samples = _config["n_samples"]
    if rejection_rates is None:
        rejection_rates = np.arange(0, 0.55, 0.05)
    rejection_rates = np.asarray(rejection_rates, dtype=float)

    target_labels = target_labels[:n_samples]
    if conditioning is not None:
        conditioning = conditioning[:n_samples]
    # Fewer labels than requested may be available; report the real count.
    n_samples = int(target_labels.shape[0])

    print(f"  Using {n_samples} samples, {len(models)} ensemble models")

    # Generate designs and compute all uncertainties
    data = generate_designs_with_uncertainties(
        models=models,
        target_labels=target_labels,
        num_design_params=num_design_params,
        num_labels=num_labels,
        conditioning=conditioning,
    )

    # Compute ground-truth errors
    print("  Computing round-trip errors...")
    with torch.no_grad():
        achieved_labels = forward_fn(data["designs"])
        errors = ((achieved_labels - target_labels) ** 2).mean(dim=1)

    errors_np = errors.detach().cpu().numpy()

    def to_numpy(t):
        return t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t

    def safe_corr(u: np.ndarray, e: np.ndarray) -> float:
        """Pearson correlation, or 0.0 when either input is constant."""
        if np.std(u) > 0 and np.std(e) > 0:
            return float(np.corrcoef(u, e)[0, 1])
        return 0.0

    # Internal metric key -> display name (insertion order = output order).
    metric_names = {
        "zero_deviation": "Zero-Deviation",
        "self_consistency": "Self-Consistency",
        "ensemble_variance": "Ensemble Variance",
        "fm_loss": "FM Loss",
    }

    curves = {}
    for metric, display_name in metric_names.items():
        uncertainty = to_numpy(data[metric])
        rates, mean_errors = compute_error_rejection_curve(
            uncertainty, errors_np, rejection_rates
        )
        curves[display_name] = {
            "rejection_rates": rates.tolist(),
            "mean_errors": mean_errors.tolist(),
            "correlation": safe_corr(uncertainty, errors_np),
        }

    # Random baseline, seeded for reproducibility.
    rng = np.random.default_rng(_config["random_seed"])
    random_errors = [
        compute_error_rejection_curve(
            rng.random(len(errors_np)), errors_np, rejection_rates
        )[1]
        for _ in range(100)
    ]
    curves["Random"] = {
        "rejection_rates": rejection_rates.tolist(),
        "mean_errors": np.mean(random_errors, axis=0).tolist(),
        "correlation": 0.0,
    }

    # Oracle (sort by actual error)
    oracle_rates, oracle_errors = compute_error_rejection_curve(
        errors_np, errors_np, rejection_rates
    )
    curves["Oracle"] = {
        "rejection_rates": oracle_rates.tolist(),
        "mean_errors": oracle_errors.tolist(),
        "correlation": 1.0,
    }

    # Baseline = no rejection, independent of whether the rates include 0.
    baseline_error = float(np.mean(errors_np))
    # All curves share the same rates, so the ~20% index is computed once.
    idx_20 = int(np.argmin(np.abs(rejection_rates - 0.20)))
    for curve in curves.values():
        error_at_20 = curve["mean_errors"][idx_20]
        if baseline_error != 0:
            curve["error_reduction_at_20pct"] = (1 - error_at_20 / baseline_error) * 100
        else:
            curve["error_reduction_at_20pct"] = 0.0

    return {
        "n_samples": n_samples,
        "num_ensemble_models": len(models),
        "baseline_error": baseline_error,
        "curves": curves,
    }


def evaluate_calibration(
    models: List[torch.nn.Module],
    target_labels: torch.Tensor,
    forward_fn: Callable[[torch.Tensor], torch.Tensor],
    num_design_params: int,
    num_labels: int,
    n_samples: Optional[int] = None,
    conditioning: Optional[torch.Tensor] = None,
    n_bins: int = 10,
) -> Dict:
    """Run calibration analysis with all uncertainty metrics.

    Args:
        models: List of ensemble models.
        target_labels: Target labels (N, L).
        forward_fn: Ground-truth forward function.
        num_design_params: Design dimension P.
        num_labels: Label dimension L.
        n_samples: Maximum number of samples to evaluate
            (default: ``_config["n_samples"]``).
        conditioning: Optional conditioning (N, C).
        n_bins: Number of calibration bins requested.

    Returns:
        Dictionary with ``n_samples`` (number of samples actually evaluated),
        ``num_ensemble_models``, ``n_bins`` (requested), ``mean_error`` (mean
        round-trip error over all samples) and ``calibration`` (display name ->
        ``compute_calibration_curve`` result).
    """
    if n_samples is None:
        n_samples = _config["n_samples"]

    target_labels = target_labels[:n_samples]
    if conditioning is not None:
        conditioning = conditioning[:n_samples]
    # Fewer labels than requested may be available; report the real count.
    n_samples = int(target_labels.shape[0])

    print(f"  Using {n_samples} samples, {len(models)} ensemble models")

    # Generate designs and compute all uncertainties
    data = generate_designs_with_uncertainties(
        models=models,
        target_labels=target_labels,
        num_design_params=num_design_params,
        num_labels=num_labels,
        conditioning=conditioning,
    )

    # Compute ground-truth errors
    print("  Computing round-trip errors...")
    with torch.no_grad():
        achieved_labels = forward_fn(data["designs"])
        errors = ((achieved_labels - target_labels) ** 2).mean(dim=1)

    errors_np = errors.detach().cpu().numpy()

    def to_numpy(t):
        return t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t

    # Internal metric key -> display name (insertion order = output order).
    metric_names = {
        "zero_deviation": "Zero-Deviation",
        "self_consistency": "Self-Consistency",
        "ensemble_variance": "Ensemble Variance",
        "fm_loss": "FM Loss",
    }

    calibration = {
        display_name: compute_calibration_curve(
            to_numpy(data[metric]), errors_np, n_bins
        )
        for metric, display_name in metric_names.items()
    }

    return {
        "n_samples": n_samples,
        "num_ensemble_models": len(models),
        "n_bins": n_bins,
        "mean_error": float(np.mean(errors_np)),
        "calibration": calibration,
    }


# =============================================================================
# Plotting Functions
# =============================================================================


def create_error_rejection_plot(
    results: Dict,
    output_file: Path,
    title: Optional[str] = None,
) -> None:
    """Create error-rejection curve plot for all UQ metrics.

    Metric curves are labelled with their error reduction at 20% rejection.
    The Random and Oracle baselines are labelled by name only. The x-axis spans
    0-50% rejection and is extended when a curve contains higher rates.

    Args:
        results: Results from evaluate_error_rejection(); reads ``curves`` and
            ``n_samples``.
        output_file: Path to save the plot (300 dpi).
        title: Plot title (default: "Error-Rejection Curves (N=<n_samples>)").
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Style configuration
    styles = {
        "Zero-Deviation": {"color": "#9b59b6", "linestyle": "-", "linewidth": 2.5},
        "Self-Consistency": {"color": "#e67e22", "linestyle": "-", "linewidth": 2.5},
        "Ensemble Variance": {"color": "#e74c3c", "linestyle": "--", "linewidth": 2},
        "FM Loss": {"color": "#2ecc71", "linestyle": "--", "linewidth": 2},
        "Random": {"color": "gray", "linestyle": ":", "linewidth": 1.5, "alpha": 0.7},
        "Oracle": {"color": "gold", "linestyle": "-", "linewidth": 2, "alpha": 0.8},
    }
    default_style = {"color": "black", "linestyle": "-", "linewidth": 1}

    # Plot order: Diag-CFM specific first, then general, then baselines
    plot_order = [
        "Zero-Deviation",
        "Self-Consistency",
        "Ensemble Variance",
        "FM Loss",
        "Random",
        "Oracle",
    ]

    x_max = 50.0
    for name in plot_order:
        if name not in results["curves"]:
            continue
        curve = results["curves"][name]

        rates = np.asarray(curve["rejection_rates"], dtype=float) * 100
        # dtype=float turns any None (e.g. from JSON round-trips) into NaN gaps.
        errors = np.asarray(curve["mean_errors"], dtype=float)
        if rates.size:
            x_max = max(x_max, float(rates.max()))

        if name in ("Random", "Oracle"):
            label = name
        else:
            reduction = curve.get("error_reduction_at_20pct", 0)
            label = f"{name} ({reduction:+.1f}% at 20%)"

        ax.plot(rates, errors, label=label, **styles.get(name, default_style))

    ax.set_xlabel("Rejection Rate (%)", fontsize=12)
    ax.set_ylabel("Mean Round-Trip Error", fontsize=12)

    if title is None:
        title = f"Error-Rejection Curves (N={results['n_samples']})"
    ax.set_title(title, fontsize=13, fontweight="bold")

    ax.legend(loc="upper right", fontsize=9)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, x_max)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Error-rejection plot saved to: {output_file}")
    plt.close(fig)


def create_calibration_plot(
    results: Dict,
    output_file: Path,
    title: Optional[str] = None,
) -> None:
    """Create calibration reliability diagram for all UQ metrics.

    Each metric is drawn as mean error per uncertainty bin, with error bars
    equal to the per-bin standard error. Empty bins are skipped.

    Args:
        results: Results from evaluate_calibration(); reads ``calibration``,
            ``n_samples`` and (optionally) ``n_bins``.
        output_file: Path to save the plot (300 dpi).
        title: Plot title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Style configuration
    styles = {
        "Zero-Deviation": {"color": "#9b59b6", "marker": "o", "linestyle": "-", "linewidth": 2},
        "Self-Consistency": {"color": "#e67e22", "marker": "s", "linestyle": "-", "linewidth": 2},
        "Ensemble Variance": {"color": "#e74c3c", "marker": "^", "linestyle": "--", "linewidth": 1.5},
        "FM Loss": {"color": "#2ecc71", "marker": "v", "linestyle": "--", "linewidth": 1.5},
    }
    default_style = {"color": "black", "marker": "o", "linestyle": "-", "linewidth": 1}

    # Plot order: Diag-CFM specific first, then general
    plot_order = [
        "Zero-Deviation",
        "Self-Consistency",
        "Ensemble Variance",
        "FM Loss",
    ]

    n_bins = int(results.get("n_bins", 10))

    for name in plot_order:
        if name not in results["calibration"]:
            continue
        cal = results["calibration"][name]

        # Use 1-based bin index as x-axis for clearer interpretation
        bin_indices = np.arange(1, len(cal["mean_errors"]) + 1)
        # Empty bins are stored as None; dtype=float maps them to NaN
        # (an object array would make np.isnan raise TypeError).
        mean_errors = np.array(cal["mean_errors"], dtype=float)
        std_errors = np.array(cal["std_errors"], dtype=float)

        valid = ~np.isnan(mean_errors)
        if not np.any(valid):
            continue

        label = f"{name} (ρ={cal['spearman_correlation']:.2f}, mono={cal['monotonicity']:.0%})"

        ax.errorbar(
            bin_indices[valid],
            mean_errors[valid],
            yerr=std_errors[valid],
            label=label,
            capsize=3,
            **styles.get(name, default_style),
        )

    if n_bins == 10:
        xlabel = "Uncertainty Decile (1=lowest, 10=highest)"
    else:
        xlabel = f"Uncertainty Bin (1=lowest, {n_bins}=highest)"
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel("Mean Round-Trip Error", fontsize=12)

    if title is None:
        title = f"Calibration Reliability Diagram (N={results['n_samples']})"
    ax.set_title(title, fontsize=13, fontweight="bold")

    ax.set_xticks(range(1, n_bins + 1))
    ax.legend(loc="upper left", fontsize=9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Calibration plot saved to: {output_file}")
    plt.close(fig)


def create_combined_calibration_plot(
    all_results: List[Dict],
    output_file: Path,
) -> None:
    """Create a multi-panel calibration plot for all datasets.

    One panel per dataset, at most three panels per row. The legend is drawn
    on the first panel only. If ``all_results`` is empty, nothing is written.

    Args:
        all_results: List of results from evaluate_calibration() for each
            dataset; ``dataset`` and ``n_bins`` keys are used when present.
        output_file: Path to save the plot (300 dpi).
    """
    n_datasets = len(all_results)
    if n_datasets == 0:
        # Avoids a ZeroDivisionError in the grid computation below.
        print("No calibration results to plot; skipping combined calibration plot.")
        return

    n_cols = min(3, n_datasets)
    n_rows = (n_datasets + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so a single ravel()
    # covers the single-panel, single-row and grid layouts alike.
    fig, axes_grid = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
    )
    axes = axes_grid.ravel()

    # Style configuration
    styles = {
        "Zero-Deviation": {"color": "#9b59b6", "marker": "o", "linestyle": "-", "linewidth": 2},
        "Self-Consistency": {"color": "#e67e22", "marker": "s", "linestyle": "-", "linewidth": 2},
        "Ensemble Variance": {"color": "#e74c3c", "marker": "^", "linestyle": "--", "linewidth": 1.5},
        "FM Loss": {"color": "#2ecc71", "marker": "v", "linestyle": "--", "linewidth": 1.5},
    }
    default_style = {"color": "black", "marker": "o", "linestyle": "-", "linewidth": 1}

    plot_order = ["Zero-Deviation", "Self-Consistency", "Ensemble Variance", "FM Loss"]

    for idx, results in enumerate(all_results):
        ax = axes[idx]
        dataset_name = results.get("dataset", f"Dataset {idx + 1}")
        n_bins = int(results.get("n_bins", 10))

        for name in plot_order:
            if name not in results["calibration"]:
                continue
            cal = results["calibration"][name]

            bin_indices = np.arange(1, len(cal["mean_errors"]) + 1)
            # Empty bins are stored as None; dtype=float maps them to NaN.
            mean_errors = np.array(cal["mean_errors"], dtype=float)

            valid = ~np.isnan(mean_errors)
            if not np.any(valid):
                continue

            ax.plot(
                bin_indices[valid],
                mean_errors[valid],
                label=name if idx == 0 else "",
                **styles.get(name, default_style),
            )

        ax.set_xlabel("Uncertainty Decile" if n_bins == 10 else "Uncertainty Bin", fontsize=10)
        ax.set_ylabel("Mean Error", fontsize=10)
        ax.set_title(dataset_name, fontsize=11, fontweight="bold")
        ax.set_xticks(range(1, n_bins + 1))
        ax.grid(True, alpha=0.3)

    # Hide unused axes
    for ax in axes[n_datasets:]:
        ax.set_visible(False)

    # Add legend to first subplot
    axes[0].legend(loc="upper left", fontsize=8)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Combined calibration plot saved to: {output_file}")
    plt.close(fig)


# =============================================================================
# Dataset-Specific Evaluation Functions
# =============================================================================


def evaluate_gas_turbine(device: torch.device, verbose: bool = True) -> Dict:
    """Run the error-rejection evaluation on the Gas Turbine test split.

    Loads the Diag-CFM ensemble and draws a seeded random subset of at most
    ``_config["n_samples"]`` test labels. Generated designs are scored by
    clamping them to [0, 1] and concatenating the outputs of the three
    gas-turbine surrogates.

    Args:
        device: Device for models, surrogates and labels.
        verbose: Print progress messages.

    Returns:
        Result of ``evaluate_error_rejection`` with ``"dataset"`` set to
        "Gas Turbine".
    """
    from uq_diagcfm.data_utils_gas_turbine import (
        GasTurbineDataset,
        LEN_LABELS,
        LEN_PARAMETERS,
        make_surrogates,
    )
    from uq_diagcfm.ensembles import load_gas_turbine_diag_cfm_ensemble

    if verbose:
        print("Evaluating Gas Turbine...")

    models, _, _, _ = load_gas_turbine_diag_cfm_ensemble(device=device, verbose=verbose)
    if verbose:
        print(f"  Using {len(models)} ensemble models")

    dataset = GasTurbineDataset(split="test")
    label_rows = [dataset[k][1] for k in range(len(dataset))]
    all_labels = torch.stack(label_rows, dim=0)

    # Seeded subset so repeated runs evaluate the same targets.
    generator = np.random.default_rng(_config["random_seed"])
    pool_size = len(all_labels)
    chosen = generator.choice(
        pool_size, size=min(_config["n_samples"], pool_size), replace=False
    )
    targets = all_labels[chosen].to(device)

    unmix_o, io_pd, ifd1 = make_surrogates()
    surrogates = (unmix_o.to(device), io_pd.to(device), ifd1.to(device))

    def forward_fn(designs):
        clamped = torch.clamp(designs, 0, 1)
        return torch.cat(tuple(surr(clamped) for surr in surrogates), dim=1)

    outcome = evaluate_error_rejection(
        models=models,
        target_labels=targets,
        forward_fn=forward_fn,
        num_design_params=LEN_PARAMETERS,
        num_labels=LEN_LABELS,
    )
    outcome["dataset"] = "Gas Turbine"
    return outcome


def evaluate_dtlz(
    device: torch.device, num_design_params: int, verbose: bool = True
) -> Optional[Dict]:
    """Run the error-rejection evaluation on the DTLZ2 benchmark.

    Uses three objectives and normalized, stratified test labels. Generated
    designs are clamped to [0, 1], passed through the DTLZ2 surrogate, and
    divided by the dataset's ``label_scale``.

    Args:
        device: Device for models and labels.
        num_design_params: Design dimension P of the benchmark variant.
        verbose: Print progress messages.

    Returns:
        Result of ``evaluate_error_rejection`` with ``"dataset"`` set to
        "DTLZ P=<P>", or ``None`` when no ensemble models are found.
    """
    from uq_diagcfm.data_utils_dtlz import DTLZDataset, make_dtlz_surrogate
    from uq_diagcfm.ensembles import load_dtlz_diag_cfm_ensemble

    if verbose:
        print(f"Evaluating DTLZ P={num_design_params}...")

    n_objectives = 3

    models, _, _, _ = load_dtlz_diag_cfm_ensemble(
        P=num_design_params, device=device, verbose=verbose
    )
    if len(models) == 0:
        if verbose:
            print(f"  No models found for DTLZ P={num_design_params}")
        return None
    if verbose:
        print(f"  Using {len(models)} ensemble models")

    dataset = DTLZDataset(
        split="test",
        num_design_params=num_design_params,
        num_objectives=n_objectives,
        function_name="dtlz2",
        normalize_labels=True,
        sampling_strategy="stratified",
    )
    all_labels = dataset.labels
    scale = dataset.label_scale

    # Seeded subset so repeated runs evaluate the same targets.
    generator = np.random.default_rng(_config["random_seed"])
    pool_size = len(all_labels)
    chosen = generator.choice(
        pool_size, size=min(_config["n_samples"], pool_size), replace=False
    )
    targets = all_labels[chosen].to(device)

    raw_objective = make_dtlz_surrogate("dtlz2", n_objectives)

    def forward_fn(designs):
        return raw_objective(torch.clamp(designs, 0, 1)) / scale

    outcome = evaluate_error_rejection(
        models=models,
        target_labels=targets,
        forward_fn=forward_fn,
        num_design_params=num_design_params,
        num_labels=n_objectives,
    )
    outcome["dataset"] = f"DTLZ P={num_design_params}"
    return outcome


def evaluate_unifoil(device: torch.device, verbose: bool = True) -> Dict:
    """Evaluate error-rejection on the Unifoil dataset.

    Uses the *validation* split. Each ``UnifoilDataset`` sample provides the
    conditioning vector at index 1 and the target label at index 2 (both
    NumPy arrays). A seeded subset of at most ``_config["n_samples"]`` rows
    is drawn without replacement. Generated designs are scored by the Unifoil
    surrogate on ``[designs, conditioning]``.

    Args:
        device: Device the labels, conditioning and surrogate are moved to.
        verbose: Print progress messages when True.

    Returns:
        The dict produced by ``evaluate_error_rejection`` with a
        ``"dataset"`` entry added.
    """
    from uq_diagcfm.data_utils_unifoil import (
        LEN_DESIGN_PARAMETERS,
        LEN_PHYSICAL_PERFORMANCE,
        UnifoilDataset,
        make_unifoil_surrogate,
    )
    from uq_diagcfm.ensembles import load_unifoil_diag_cfm_ensemble

    if verbose:
        print("Evaluating Unifoil...")

    num_labels = LEN_PHYSICAL_PERFORMANCE

    models, _, _, _ = load_unifoil_diag_cfm_ensemble(device=device, verbose=verbose)
    if verbose:
        print(f"  Using {len(models)} ensemble models")

    val_dataset = UnifoilDataset(split="val")
    # Read each sample exactly once, so labels and conditioning always come
    # from the same fetch and the dataset is not loaded twice.
    label_rows = []
    conditioning_rows = []
    for i in range(len(val_dataset)):
        sample = val_dataset[i]
        conditioning_rows.append(torch.from_numpy(sample[1]))
        label_rows.append(torch.from_numpy(sample[2]))
    all_labels = torch.stack(label_rows, dim=0)
    all_conditioning = torch.stack(conditioning_rows, dim=0)

    rng = np.random.default_rng(_config["random_seed"])
    n_available = len(all_labels)
    indices = rng.choice(
        n_available, size=min(_config["n_samples"], n_available), replace=False
    )
    test_labels = all_labels[indices].to(device)
    conditioning = all_conditioning[indices].to(device)

    surrogate = make_unifoil_surrogate().to(device)
    surrogate.eval()

    def forward_fn(designs):
        # NOTE(review): conditioning is sliced from row 0, so this assumes the
        # rows of `designs` line up with the first designs.shape[0] targets
        # (i.e. callers never pass an offset batch). Confirm against
        # evaluate_error_rejection.
        surrogate_input = torch.cat([designs, conditioning[: designs.shape[0]]], dim=1)
        with torch.no_grad():
            return surrogate(surrogate_input)

    results = evaluate_error_rejection(
        models=models,
        target_labels=test_labels,
        forward_fn=forward_fn,
        num_design_params=LEN_DESIGN_PARAMETERS,
        num_labels=num_labels,
        conditioning=conditioning,
    )
    results["dataset"] = "Unifoil"
    return results


# =============================================================================
# Dataset-Specific Calibration Evaluation Functions
# =============================================================================


def calibrate_gas_turbine(device: torch.device, verbose: bool = True) -> Dict:
    """Run the calibration analysis on the Gas Turbine dataset.

    Target labels are element 1 of each ``GasTurbineDataset`` test sample. A
    seeded subset of at most ``_config["n_samples"]`` rows is drawn without
    replacement. Generated designs are clamped to ``[0, 1]`` and scored by
    three surrogates (Unmix_O, IO_PD, IFD1), whose outputs are concatenated
    along dim 1 in that order.

    Args:
        device: Device the labels and surrogates are moved to.
        verbose: Print progress messages when True.

    Returns:
        The dict produced by ``evaluate_calibration`` with a ``"dataset"``
        entry added.
    """
    from uq_diagcfm.data_utils_gas_turbine import (
        GasTurbineDataset,
        LEN_LABELS,
        LEN_PARAMETERS,
        make_surrogates,
    )
    from uq_diagcfm.ensembles import load_gas_turbine_diag_cfm_ensemble

    if verbose:
        print("Calibration Analysis: Gas Turbine...")

    ensemble, _, _, _ = load_gas_turbine_diag_cfm_ensemble(
        device=device, verbose=verbose
    )
    if verbose:
        print(f"  Using {len(ensemble)} ensemble models")

    test_split = GasTurbineDataset(split="test")
    pool = torch.stack([test_split[idx][1] for idx in range(len(test_split))], dim=0)

    # Seeded subsample without replacement so runs are reproducible.
    generator = np.random.default_rng(_config["random_seed"])
    pool_size = len(pool)
    chosen = generator.choice(
        pool_size, size=min(_config["n_samples"], pool_size), replace=False
    )
    targets = pool[chosen].to(device)

    unmix_o, io_pd, ifd1 = make_surrogates()
    surrogate_chain = (unmix_o.to(device), io_pd.to(device), ifd1.to(device))

    def forward_fn(designs):
        unit_designs = torch.clamp(designs, 0, 1)
        return torch.cat([net(unit_designs) for net in surrogate_chain], dim=1)

    results = evaluate_calibration(
        models=ensemble,
        target_labels=targets,
        forward_fn=forward_fn,
        num_design_params=LEN_PARAMETERS,
        num_labels=LEN_LABELS,
    )
    results["dataset"] = "Gas Turbine"
    return results


def calibrate_dtlz(
    device: torch.device, num_design_params: int, verbose: bool = True
) -> Optional[Dict]:
    """Run the calibration analysis on the DTLZ2 benchmark.

    Mirrors the DTLZ error-rejection setup. It uses the same ensemble loader,
    the same seeded subset of normalized test labels, and the same
    ``label_scale``-normalized DTLZ2 objective on designs clamped to
    ``[0, 1]``.

    Args:
        device: Device the target labels are moved to and the ensemble is
            loaded onto.
        num_design_params: DTLZ design dimension ``P``.
        verbose: Print progress messages when True.

    Returns:
        The dict produced by ``evaluate_calibration`` with a ``"dataset"``
        entry added, or ``None`` when no ensemble checkpoints exist for
        this ``P``.
    """
    from uq_diagcfm.data_utils_dtlz import DTLZDataset, make_dtlz_surrogate
    from uq_diagcfm.ensembles import load_dtlz_diag_cfm_ensemble

    if verbose:
        print(f"Calibration Analysis: DTLZ P={num_design_params}...")

    objective_count = 3

    ensemble, _, _, _ = load_dtlz_diag_cfm_ensemble(
        P=num_design_params, device=device, verbose=verbose
    )
    if len(ensemble) == 0:
        if verbose:
            print(f"  No models found for DTLZ P={num_design_params}")
        return None
    if verbose:
        print(f"  Using {len(ensemble)} ensemble models")

    split = DTLZDataset(
        split="test",
        num_design_params=num_design_params,
        num_objectives=objective_count,
        function_name="dtlz2",
        normalize_labels=True,
        sampling_strategy="stratified",
    )
    label_pool = split.labels
    scale = split.label_scale

    sampler = np.random.default_rng(_config["random_seed"])
    take = min(_config["n_samples"], len(label_pool))
    picked = sampler.choice(len(label_pool), size=take, replace=False)
    targets = label_pool[picked].to(device)

    objective = make_dtlz_surrogate("dtlz2", objective_count)

    def forward_fn(designs):
        bounded = designs.clamp(0, 1)
        raw = objective(bounded)
        return raw / scale

    results = evaluate_calibration(
        models=ensemble,
        target_labels=targets,
        forward_fn=forward_fn,
        num_design_params=num_design_params,
        num_labels=objective_count,
    )
    results["dataset"] = f"DTLZ P={num_design_params}"
    return results


def calibrate_unifoil(device: torch.device, verbose: bool = True) -> Dict:
    """Evaluate calibration on the Unifoil dataset.

    Uses the *validation* split. Each ``UnifoilDataset`` sample provides the
    conditioning vector at index 1 and the target label at index 2 (both
    NumPy arrays). A seeded subset of at most ``_config["n_samples"]`` rows
    is drawn without replacement. Generated designs are scored by the Unifoil
    surrogate on ``[designs, conditioning]``.

    Args:
        device: Device the labels, conditioning and surrogate are moved to.
        verbose: Print progress messages when True.

    Returns:
        The dict produced by ``evaluate_calibration`` with a ``"dataset"``
        entry added.
    """
    from uq_diagcfm.data_utils_unifoil import (
        LEN_DESIGN_PARAMETERS,
        LEN_PHYSICAL_PERFORMANCE,
        UnifoilDataset,
        make_unifoil_surrogate,
    )
    from uq_diagcfm.ensembles import load_unifoil_diag_cfm_ensemble

    if verbose:
        print("Calibration Analysis: Unifoil...")

    num_labels = LEN_PHYSICAL_PERFORMANCE

    models, _, _, _ = load_unifoil_diag_cfm_ensemble(device=device, verbose=verbose)
    if verbose:
        print(f"  Using {len(models)} ensemble models")

    val_dataset = UnifoilDataset(split="val")
    # Read each sample exactly once, so labels and conditioning always come
    # from the same fetch and the dataset is not loaded twice.
    label_rows = []
    conditioning_rows = []
    for i in range(len(val_dataset)):
        sample = val_dataset[i]
        conditioning_rows.append(torch.from_numpy(sample[1]))
        label_rows.append(torch.from_numpy(sample[2]))
    all_labels = torch.stack(label_rows, dim=0)
    all_conditioning = torch.stack(conditioning_rows, dim=0)

    rng = np.random.default_rng(_config["random_seed"])
    n_available = len(all_labels)
    indices = rng.choice(
        n_available, size=min(_config["n_samples"], n_available), replace=False
    )
    test_labels = all_labels[indices].to(device)
    conditioning = all_conditioning[indices].to(device)

    surrogate = make_unifoil_surrogate().to(device)
    surrogate.eval()

    def forward_fn(designs):
        # NOTE(review): conditioning is sliced from row 0, so this assumes the
        # rows of `designs` line up with the first designs.shape[0] targets
        # (i.e. callers never pass an offset batch). Confirm against
        # evaluate_calibration.
        surrogate_input = torch.cat([designs, conditioning[: designs.shape[0]]], dim=1)
        with torch.no_grad():
            return surrogate(surrogate_input)

    results = evaluate_calibration(
        models=models,
        target_labels=test_labels,
        forward_fn=forward_fn,
        num_design_params=LEN_DESIGN_PARAMETERS,
        num_labels=num_labels,
        conditioning=conditioning,
    )
    results["dataset"] = "Unifoil"
    return results


# =============================================================================
# Memory Management
# =============================================================================


def clear_memory():
    """Free Python garbage and release cached accelerator memory.

    Runs the garbage collector first so unreferenced tensors are dropped.
    Then it asks each *available* backend allocator (CUDA, MPS) to return
    its cached blocks.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # torch.mps.empty_cache() exists on every recent build, but it raises
    # RuntimeError ("... without MPS backend") when MPS is not usable
    # (e.g. Linux/CUDA machines). Only call it when the backend is available.
    mps_backend = getattr(torch.backends, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(torch, "mps")
        and hasattr(torch.mps, "empty_cache")
    ):
        torch.mps.empty_cache()


# =============================================================================
# CLI Interface
# =============================================================================


def _json_default(obj):
    """Convert NumPy / PyTorch values that ``json`` cannot encode natively.

    ``json.dump`` calls this only for objects it cannot serialize itself.
    Arrays and tensors become (nested) lists, and NumPy scalars become Python
    scalars. Anything else raises ``TypeError``, as ``json`` would.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """CLI entry point for error-rejection and calibration evaluation.

    For each selected dataset, runs the error-rejection and/or calibration
    analysis, prints a per-metric summary, and writes a PNG figure to
    ``PAPER_FIGURES_DIR``. When calibration results exist for more than one
    dataset, it also writes a combined plot and a JSON dump (under
    ``RESULTS_UQ_DIR``) and prints a cross-dataset Spearman summary table.
    """
    import json

    from uq_diagcfm.paths import PAPER_FIGURES_DIR, RESULTS_UQ_DIR, ensure_paper_dirs_exist

    # UQ metrics reported in every summary table, in display order.
    uq_metric_names = (
        "Zero-Deviation",
        "Self-Consistency",
        "Ensemble Variance",
        "FM Loss",
    )

    parser = argparse.ArgumentParser(
        description="Compute error-rejection curves or calibration analysis for UQ metrics."
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="error_rejection",
        choices=["error_rejection", "calibration", "both"],
        help="Analysis mode: error_rejection, calibration, or both",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="gas_turbine",
        choices=["gas_turbine", "unifoil", "dtlz", "all"],
        help="Dataset to evaluate (use 'all' for all datasets)",
    )
    parser.add_argument(
        "--P",
        type=int,
        default=50,
        help="For DTLZ: design dimension P",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=DEFAULT_N_SAMPLES,
        help=f"Number of samples (default: {DEFAULT_N_SAMPLES})",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_RANDOM_SEED,
        help=f"Random seed (default: {DEFAULT_RANDOM_SEED})",
    )

    args = parser.parse_args()

    _config["n_samples"] = args.n_samples
    _config["random_seed"] = args.seed

    ensure_paper_dirs_exist()
    RESULTS_UQ_DIR.mkdir(exist_ok=True, parents=True)

    device = get_device()

    # Determine which datasets to evaluate
    if args.dataset == "all":
        datasets = ["gas_turbine", "unifoil"]
        # Add DTLZ with different P values
        for P in [12, 24, 50, 100]:
            datasets.append(f"dtlz_{P}")
    else:
        datasets = [args.dataset if args.dataset != "dtlz" else f"dtlz_{args.P}"]

    all_calibration_results = []

    for dataset in datasets:
        print(f"\n{'='*80}")
        print(f"Processing: {dataset}")
        print("=" * 80)

        # Run error-rejection if requested
        if args.mode in ["error_rejection", "both"]:
            if dataset == "gas_turbine":
                results = evaluate_gas_turbine(device)
                output_file = PAPER_FIGURES_DIR / "error_rejection_gas_turbine.png"
            elif dataset == "unifoil":
                results = evaluate_unifoil(device)
                output_file = PAPER_FIGURES_DIR / "error_rejection_unifoil.png"
            elif dataset.startswith("dtlz_"):
                P = int(dataset.split("_")[1])
                results = evaluate_dtlz(device, P)
                if results is None:
                    # No checkpoints for this P; calibration would find none
                    # either, so skip the dataset entirely.
                    continue
                output_file = PAPER_FIGURES_DIR / f"error_rejection_dtlz_P{P}.png"
            else:
                continue

            # Print error-rejection summary
            print(f"\nError-Rejection Results: {results['dataset']}")
            print("-" * 60)
            print(f"{'Metric':<20} {'Correlation':>12} {'Reduction @ 20%':>15}")
            print("-" * 60)

            for name in (*uq_metric_names, "Oracle"):
                if name in results["curves"]:
                    curve = results["curves"][name]
                    print(
                        f"{name:<20} {curve['correlation']:>12.4f} "
                        f"{curve['error_reduction_at_20pct']:>+14.1f}%"
                    )

            create_error_rejection_plot(
                results,
                output_file,
                title=f"Error-Rejection Curves: {results['dataset']}",
            )

            clear_memory()

        # Run calibration if requested
        if args.mode in ["calibration", "both"]:
            if dataset == "gas_turbine":
                cal_results = calibrate_gas_turbine(device)
                output_file = PAPER_FIGURES_DIR / "calibration_gas_turbine.png"
            elif dataset == "unifoil":
                cal_results = calibrate_unifoil(device)
                output_file = PAPER_FIGURES_DIR / "calibration_unifoil.png"
            elif dataset.startswith("dtlz_"):
                P = int(dataset.split("_")[1])
                cal_results = calibrate_dtlz(device, P)
                if cal_results is None:
                    continue
                output_file = PAPER_FIGURES_DIR / f"calibration_dtlz_P{P}.png"
            else:
                continue

            all_calibration_results.append(cal_results)

            # Print calibration summary (column widths match the header).
            print(f"\nCalibration Results: {cal_results['dataset']}")
            print("-" * 80)
            print(f"{'Metric':<20} {'Spearman ρ':>12} {'Monotonicity':>12} {'ECE':>10}")
            print("-" * 80)

            for name in uq_metric_names:
                if name in cal_results["calibration"]:
                    cal = cal_results["calibration"][name]
                    print(
                        f"{name:<20} {cal['spearman_correlation']:>12.4f} "
                        f"{cal['monotonicity']:>12.0%} {cal['ece']:>10.4f}"
                    )

            create_calibration_plot(
                cal_results,
                output_file,
                title=f"Calibration: {cal_results['dataset']}",
            )

            clear_memory()

    # Create combined calibration plot if we have multiple datasets
    if args.mode in ["calibration", "both"] and len(all_calibration_results) > 1:
        combined_output = PAPER_FIGURES_DIR / "calibration_all_datasets.png"
        create_combined_calibration_plot(all_calibration_results, combined_output)

        # Save combined results as JSON; _json_default handles NumPy/torch
        # values that json.dump would otherwise reject.
        json_output = RESULTS_UQ_DIR / "calibration_all_datasets.json"
        with open(json_output, "w", encoding="utf-8") as f:
            json.dump(all_calibration_results, f, indent=2, default=_json_default)
        print(f"\nCalibration results saved to: {json_output}")

        # Print summary table: one Spearman column per entry of uq_metric_names.
        summary_headers = ("Zero-Dev ρ", "Self-Cons ρ", "Ens.Var ρ", "FMLoss ρ")
        print("\n" + "=" * 100)
        print("CALIBRATION SUMMARY (All Datasets)")
        print("=" * 100)
        print(f"{'Dataset':<20}" + "".join(f" {header:>12}" for header in summary_headers))
        print("-" * 100)

        for res in all_calibration_results:
            row = f"{res['dataset']:<20}"
            for name in uq_metric_names:
                if name in res["calibration"]:
                    row += f" {res['calibration'][name]['spearman_correlation']:>12.4f}"
                else:
                    row += f" {'N/A':>12}"
            print(row)


if __name__ == "__main__":
    # Run the CLI only when executed as a script / via `python -m`, not on import.
    main()
