"""
Common utilities for UQ evaluation across datasets.

This module provides shared functions for uncertainty quantification (UQ) evaluation
with out-of-distribution (OOD) detection. It consolidates common patterns from
dataset-specific evaluation scripts (Gas Turbine, Unifoil, DTLZ).

Key functions:
- Metric computation: compute_zero_deviation, compute_self_consistency
- FM loss computation: compute_fm_losses
- Ensemble computation: compute_ensemble_variance
- Main pipeline: evaluate_uq_with_ood_generic
- Visualization: create_violin_plots, create_roc_curves, create_point_cloud_plot
"""

import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import torch
from sklearn.metrics import auc, roc_curve, roc_auc_score
from tqdm import tqdm

from uq_diagcfm.ood_utils import get_ood_points_by_difficulty
from uq_diagcfm.solvers import euler_method, euler_method_with_conditioning
from uq_diagcfm.utils import get_device

# Fixed display order for the UQ metrics in figures (ROC curves, violin plots),
# so that every dataset and every plotting helper lists them identically.
UQ_METRIC_ORDER = ["zero_deviation", "self_consistency", "ensemble_variance", "fm_loss"]


# =============================================================================
# Generic UQ Evaluation Functions
# =============================================================================


def sample_in_dist_and_ood(
    val_labels: torch.Tensor,
    ood_points: torch.Tensor,
    nb_samples: int,
    device: torch.device,
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
    """Randomly sample in-distribution and OOD points with a fixed seed.

    In-distribution points are drawn without replacement from the validation
    labels, OOD points without replacement from ``ood_points``. Each category
    is capped at the number of rows available, so fewer than ``nb_samples``
    rows may be returned.

    Note:
        This calls ``torch.manual_seed(seed)``, which reseeds torch's *global*
        RNG (on all devices). Random draws made after this call, for example
        the noise complement drawn afterwards in
        ``evaluate_uq_with_ood_generic``, are therefore reproducible as well.
        Keep this in mind before switching to a private generator.

    Args:
        val_labels: Validation labels tensor (N_val, L).
        ood_points: OOD points tensor (N_ood, L).
        nb_samples: Maximum number of samples to draw from each category;
            must be non-negative.
        device: Device the sampled labels are moved to.
        seed: Random seed for reproducible sampling (default: 42).

    Returns:
        Tuple of:
        - in_dist_labels: Sampled in-distribution labels
          (min(nb_samples, N_val), L), on ``device``.
        - ood_labels: Sampled OOD labels (min(nb_samples, N_ood), L), on
          ``device``.
        - ood_mask: Integer array of length n_in + n_ood. It is 0 for the
          leading in-distribution rows and 1 for the trailing OOD rows,
          matching ``torch.cat([in_dist_labels, ood_labels])``.

    Raises:
        ValueError: If ``nb_samples`` is negative. A negative slice bound
            would otherwise silently select ``N - |nb_samples|`` rows.
    """
    if nb_samples < 0:
        raise ValueError(f"nb_samples must be non-negative, got {nb_samples}")

    # Reseed the global RNG (intentionally global; see Note above)
    torch.manual_seed(seed)

    # Randomly sample in-distribution points
    n_in_dist = min(nb_samples, val_labels.shape[0])
    in_dist_perm = torch.randperm(val_labels.shape[0])[:n_in_dist]
    in_dist_labels = val_labels[in_dist_perm].to(device)

    # Randomly sample OOD points
    n_ood = min(nb_samples, ood_points.shape[0])
    ood_perm = torch.randperm(ood_points.shape[0])[:n_ood]
    ood_labels = ood_points[ood_perm].to(device)

    # OOD mask: in-distribution rows come first, OOD rows second
    total_samples = in_dist_labels.shape[0] + ood_labels.shape[0]
    ood_mask = np.zeros(total_samples, dtype=int)
    ood_mask[in_dist_labels.shape[0]:] = 1

    print(f"Sampled points (seed={seed}):")
    print(f"  In-distribution: {in_dist_labels.shape[0]}")
    print(f"  OOD: {ood_labels.shape[0]}")
    print(f"  Total: {total_samples}")

    return in_dist_labels, ood_labels, ood_mask


def create_y_complement(
    batch_size: int,
    num_design_params: int,
    num_labels: int,
    device: torch.device,
) -> torch.Tensor:
    """Draw the noise complement ``z`` that pads labels to the Diag-CFM state size.

    In Diag-CFM the state at t=1 is the concatenation ``[y, z]``. This returns
    only ``z``, sampled uniformly on [0, 1) with ``torch.rand`` from the global
    torch RNG.

    Args:
        batch_size: Number of rows to draw.
        num_design_params: Design space dimension P (the width of ``z``).
        num_labels: Label space dimension L. It is not used by the sampling and
            is kept for signature compatibility.
        device: Device on which the noise is allocated.

    Returns:
        Tensor of shape (batch_size, P).
    """
    noise_shape = (batch_size, num_design_params)
    return torch.rand(noise_shape, device=device)


def flow_matching_loss_for_uq(
    model_params: dict,
    x: torch.Tensor,
    augmented_y: torch.Tensor,
    noise: torch.Tensor,
    model: torch.nn.Module,
    device: torch.device,
    num_labels: int,
    conditioning: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Single-sample flow matching loss, used as an uncertainty score.

    The design ``x`` is zero-padded to ``s_0 = [0; x]`` and paired with
    ``s_1 = augmented_y``. The model is called through
    ``torch.func.functional_call`` with ``model_params``. It predicts a
    velocity at the midpoint ``s_t = (1 - t) s_0 + t s_1`` with t = 0.5. The
    loss is the mean squared difference between that prediction and the
    straight-line target ``s_1 - s_0``, restricted to the first
    ``num_labels`` entries.

    Args:
        model_params: Parameter dict passed to ``functional_call``.
        x: Generated design parameters (P,).
        augmented_y: Augmented label vector [y; noise] (L + P,).
        noise: Ignored; kept so the vmapped call signature stays unchanged.
        model: Module whose forward is invoked functionally.
        device: Device for the time tensor and the zero padding.
        num_labels: Number of label dimensions L.
        conditioning: Optional conditioning vector (C,), placed between the
            state and the time in the model input.

    Returns:
        0-d tensor with this sample's loss.
    """
    half = torch.tensor([0.5], device=device)

    # s_0 = [0; x]: Diag-CFM leaves the label slots zero at t=0
    start_state = torch.cat((torch.zeros(num_labels, device=device), x), dim=0)
    midpoint = (1 - half) * start_state + half * augmented_y

    if conditioning is None:
        pieces = (midpoint, half)
    else:
        pieces = (midpoint, conditioning, half)
    velocity = torch.func.functional_call(
        model, model_params, torch.cat(pieces, dim=0)
    )

    # Only the label block is scored, because OOD is defined in label space
    residual = velocity[:num_labels] - (augmented_y - start_state)[:num_labels]
    return (residual ** 2).mean()


def compute_fm_losses(
    simulated_designs: torch.Tensor,
    augmented_labels: torch.Tensor,
    model: torch.nn.Module,
    num_labels: int,
    device: torch.device,
    conditioning: Optional[torch.Tensor] = None,
    batch_size: int = 100,
) -> torch.Tensor:
    """Compute FM losses for all samples (forward pass only, no gradient computation).

    Samples are processed in batches of ``batch_size``. Within a batch, the
    single-sample ``flow_matching_loss_for_uq`` is vectorized with
    ``torch.func.vmap``. Everything runs under ``torch.no_grad()``, so no
    autograd graph is built for the model parameters.

    Non-finite losses (NaN/±inf) are replaced by 1e10, and all losses are
    clamped to at most 1e10, so downstream statistics stay finite.

    Args:
        simulated_designs: Generated designs (N, P).
        augmented_labels: Augmented labels (N, L + P).
        model: Reference model.
        num_labels: Number of labels L.
        device: Computation device.
        conditioning: Optional conditioning (N, C) for conditional models.
        batch_size: Batch size for computation.

    Returns:
        FM losses tensor of shape (N,) on the CPU. For N == 0 an empty tensor
        is returned instead of ``torch.cat`` raising on an empty list.
    """
    max_finite_loss = 1e10

    # `noise` is ignored by flow_matching_loss_for_uq. It is still drawn so the
    # global RNG advances exactly as before for any code that runs afterwards.
    noise = torch.randn(augmented_labels.shape[1], device=device) * 0.001
    params = dict(model.named_parameters())

    # Build the vectorized loss once, instead of redefining it per batch.
    # model/device/num_labels/noise are shared by every sample, so they are
    # closed over rather than threaded through vmap with in_dims=None.
    if conditioning is None:

        def _loss_uncond(p, x, y):
            return flow_matching_loss_for_uq(
                p, x, y, noise, model, device, num_labels, None
            )

        batched_loss = torch.func.vmap(
            _loss_uncond, in_dims=(None, 0, 0), randomness="same"
        )
    else:

        def _loss_cond(p, x, y, c):
            return flow_matching_loss_for_uq(
                p, x, y, noise, model, device, num_labels, c
            )

        batched_loss = torch.func.vmap(
            _loss_cond, in_dims=(None, 0, 0, 0), randomness="same"
        )

    fm_losses = []
    with torch.no_grad():
        for start in tqdm(
            range(0, simulated_designs.shape[0], batch_size),
            desc="Computing FM losses",
        ):
            stop = start + batch_size
            batch_args = [
                params,
                simulated_designs[start:stop],
                augmented_labels[start:stop],
            ]
            if conditioning is not None:
                batch_args.append(conditioning[start:stop])

            b_losses = batched_loss(*batch_args).detach().cpu()
            # Map NaN/±inf to the cap, then clamp any finite overflow to it
            b_losses = torch.where(
                torch.isfinite(b_losses),
                b_losses,
                torch.full_like(b_losses, max_finite_loss),
            )
            fm_losses.append(torch.clamp(b_losses, max=max_finite_loss))

    if not fm_losses:
        return torch.empty(0)
    return torch.cat(fm_losses)


def compute_ensemble_variance(
    simulated_designs: torch.Tensor,
    models: List[torch.nn.Module],
    num_labels: int,
    device: torch.device,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
) -> torch.Tensor:
    """Disagreement of the ensemble's analysis-pass label predictions.

    Every model integrates ``[0; x]`` from t=0 to t=1 with ``steps`` Euler
    steps. The first ``num_labels`` coordinates are taken as its label
    prediction. The score is the variance across models, averaged over the
    label dimensions.

    Per-model predictions have NaN mapped to 0 and ±inf mapped to ±1e6. A
    non-finite variance becomes 1e10, and all variances are clamped at 1e10.
    ``Tensor.var`` uses the unbiased estimator by default, so a single-model
    ensemble produces NaN variances, which are reported as 1e10.

    Args:
        simulated_designs: Generated designs (N, P).
        models: List of ensemble models.
        num_labels: Number of labels L.
        device: Device for the zero label padding.
        conditioning: Optional conditioning (N, C) for conditional models.
        steps: Number of integration steps.

    Returns:
        Ensemble variance tensor of shape (N,).
    """
    label_padding = torch.zeros(simulated_designs.shape[0], num_labels, device=device)
    start_states = torch.cat([label_padding, simulated_designs], dim=1)

    per_model = []
    for member in models:
        if conditioning is None:
            final_state = euler_method(
                model=member, input=start_states, start_t=0, end_t=1, steps=steps
            )
        else:
            final_state = euler_method_with_conditioning(
                model=member,
                input=start_states,
                conditioning=conditioning,
                start_t=0,
                end_t=1,
                steps=steps,
            )
        label_pred = final_state[:, :num_labels]
        per_model.append(
            torch.nan_to_num(label_pred, nan=0.0, posinf=1e6, neginf=-1e6)
        )

    stacked = torch.stack(per_model, dim=2)  # (N, L, num_models)
    variance = stacked.var(dim=2).mean(dim=1)  # (N,)

    cap = 1e10
    variance = torch.where(
        torch.isfinite(variance), variance, torch.full_like(variance, cap)
    )
    return variance.clamp(max=cap)


def compute_zero_deviation(
    augmented_output: torch.Tensor,
    num_labels: int,
) -> torch.Tensor:
    """Squared L2 norm of the label block of a synthesis output.

    After the Diag-CFM synthesis pass (t=1 → t=0), the first ``num_labels``
    coordinates of the state should be close to zero. This score measures how
    far they are from zero.

    Args:
        augmented_output: Full output from the synthesis pass (N, L + P).
        num_labels: Number of labels L.

    Returns:
        Tensor of shape (N,) holding ``sum_l output[:, l] ** 2``.
    """
    leftover_labels = augmented_output[:, :num_labels]
    return torch.sum(leftover_labels.pow(2), dim=1)


def compute_self_consistency(
    model: torch.nn.Module,
    simulated_designs: torch.Tensor,
    target_labels: torch.Tensor,
    num_labels: int,
    device: torch.device,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
    batch_size: int = 500,
) -> torch.Tensor:
    """Compute self-consistency error via analysis pass.

    Self-consistency measures whether generated designs reconstruct the target
    labels when passed through the analysis direction (t=0 → t=1). Each design
    ``x`` is padded to ``[0; x]`` and integrated with the Euler solver. The
    first ``num_labels`` coordinates of the result are the reconstructed
    labels.

    Args:
        model: Flow matching model.
        simulated_designs: Generated designs of shape (N, P).
        target_labels: Target labels of shape (N, L).
        num_labels: Number of labels L.
        device: Computation device.
        conditioning: Optional conditioning tensor (N, C) for conditional models.
        steps: Number of Euler integration steps.
        batch_size: Batch size for processing.

    Returns:
        Per-sample mean squared error over the L label dimensions,
        ``mean_l (y_reconstructed[l] - y*[l]) ** 2``, of shape (N,). This is a
        mean over labels, not the squared norm.
        - Reconstructed NaN/±inf values are mapped to 0/±1e6 before the
          error is taken.
        - Non-finite errors become 1e10, and all errors are clamped at 1e10.
        - For N == 0 an empty tensor on ``device`` is returned.
    """
    if simulated_designs.shape[0] == 0:
        # Nothing to integrate; torch.cat below would raise on an empty list
        return torch.empty(0, device=device)

    self_consistency_errors = []

    for i in tqdm(range(0, simulated_designs.shape[0], batch_size),
                  desc="Computing self-consistency"):
        batch_designs = simulated_designs[i:i + batch_size]
        batch_labels = target_labels[i:i + batch_size]

        # Create augmented input: [0; x] (zeros for labels, design for rest)
        augmented_input = torch.cat([
            torch.zeros(batch_designs.shape[0], num_labels, device=device),
            batch_designs,
        ], dim=1)

        # Run analysis pass (t=0 → t=1)
        if conditioning is not None:
            batch_cond = conditioning[i:i + batch_size]
            result = euler_method_with_conditioning(
                model=model,
                input=augmented_input,
                conditioning=batch_cond,
                start_t=0,
                end_t=1,
                steps=steps,
            )
        else:
            result = euler_method(
                model=model,
                input=augmented_input,
                start_t=0,
                end_t=1,
                steps=steps,
            )

        # Extract reconstructed labels and sanitize diverged trajectories
        y_reconstructed = torch.nan_to_num(
            result[:, :num_labels], nan=0.0, posinf=1e6, neginf=-1e6
        )

        # Self-consistency error: MSE (mean over labels) to the target
        batch_errors = ((y_reconstructed - batch_labels) ** 2).mean(dim=1)

        # Errors are non-negative; cap any non-finite/overflowing value
        batch_errors = torch.nan_to_num(batch_errors, nan=1e10, posinf=1e10, neginf=0.0)
        batch_errors = torch.clamp(batch_errors, max=1e10)

        self_consistency_errors.append(batch_errors)

    return torch.cat(self_consistency_errors)


def generate_designs_from_labels(
    model: torch.nn.Module,
    augmented_labels: torch.Tensor,
    num_labels: int,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
) -> torch.Tensor:
    """Run the synthesis (inverse) pass from t=1 to t=0 and keep the design part.

    Args:
        model: Flow matching model.
        augmented_labels: Augmented labels [y; z] (N, L + P), used as the t=1 state.
        num_labels: Number of labels L.
        conditioning: Optional conditioning (N, C) for conditional models.
        steps: Number of Euler integration steps.

    Returns:
        Generated designs (N, P): every column after the first ``num_labels``.
    """
    solver_kwargs = dict(
        model=model, input=augmented_labels, start_t=1, end_t=0, steps=steps
    )
    if conditioning is None:
        final_state = euler_method(**solver_kwargs)
    else:
        final_state = euler_method_with_conditioning(
            conditioning=conditioning, **solver_kwargs
        )

    # The leading L columns are label slots; the remaining P are the design
    return final_state[:, num_labels:]


def create_uq_results_dict(
    ensemble_variance: np.ndarray,
    fm_losses: np.ndarray,
    error_metric: np.ndarray,
    error_metric_name: str,
    ood_mask: np.ndarray,
    nb_samples: int,
    checkpoint_names: List[str],
    ensemble_num_params: List[int],
    extra_fields: Optional[Dict] = None,
) -> Dict:
    """Assemble the standardized, JSON-friendly UQ results dictionary.

    Key order is: run metadata (``nb_samples_per_category``, ``categories``,
    ``checkpoint_names``, ``num_parameters``, ``ood_mask``), followed by
    mean/std/min/max summaries for ``ensemble_variance``, ``fm_loss`` and
    ``error_metric_name``. ``extra_fields``, when non-empty, is merged last
    and may overwrite any of these keys.

    Args:
        ensemble_variance: Ensemble variance values (N,).
        fm_losses: Flow matching loss values (N,).
        error_metric: Error metric values (N,), e.g. surrogate or
            self-consistency error.
        error_metric_name: Key under which the error metric summary is stored.
        ood_mask: Binary mask (N,): 1 for OOD, 0 for in-distribution.
        nb_samples: Number of samples per category.
        checkpoint_names: List of checkpoint names.
        ensemble_num_params: List of parameter counts per model.
        extra_fields: Optional extra fields to include.

    Returns:
        Results dictionary with statistics (values converted to Python floats).
    """

    def _describe(values: np.ndarray) -> Dict[str, float]:
        # Summary statistics recorded for each per-sample metric
        return {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
        }

    results = {
        "nb_samples_per_category": nb_samples,
        "categories": ["in_distribution", "ood"],
        "checkpoint_names": checkpoint_names,
        "num_parameters": ensemble_num_params,
        "ood_mask": ood_mask.tolist(),
    }
    results["ensemble_variance"] = _describe(ensemble_variance)
    results["fm_loss"] = _describe(fm_losses)
    results[error_metric_name] = _describe(error_metric)

    if extra_fields:
        results.update(extra_fields)

    return results


def save_uq_results(
    results: Dict,
    ensemble_variance: np.ndarray,
    fm_losses: np.ndarray,
    error_metric: np.ndarray,
    error_metric_name: str,
    ood_mask: np.ndarray,
    output_dir: Path,
    filename_prefix: str,
    epochs: int,
) -> Tuple[Path, Path]:
    """Persist UQ results: summary statistics as JSON, raw arrays as NPZ.

    ``output_dir`` is created (with parents) if missing, then two files are
    written:

    - ``{filename_prefix}_results_ep{epochs}.json``: ``results`` dumped with
      indent=2.
    - ``{filename_prefix}_data_ep{epochs}.npz``: arrays stored under the keys
      ``ensemble_variance``, ``fm_loss``, ``<error_metric_name>`` and
      ``ood_mask``, in that order.

    Args:
        results: Results dictionary.
        ensemble_variance: Ensemble variance values.
        fm_losses: Flow matching loss values.
        error_metric: Error metric values.
        error_metric_name: NPZ key for the error metric.
        ood_mask: Binary OOD mask.
        output_dir: Output directory.
        filename_prefix: Prefix for output files (e.g., "uq_ood_gas_turbine").
        epochs: Number of training epochs (embedded in the filenames).

    Returns:
        Tuple of (json_path, npz_path).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    def artifact_path(kind: str, extension: str) -> Path:
        # e.g. <prefix>_results_ep<epochs>.json / <prefix>_data_ep<epochs>.npz
        return output_dir / f"{filename_prefix}_{kind}_ep{epochs}.{extension}"

    json_path = artifact_path("results", "json")
    with open(json_path, "w") as stream:
        json.dump(results, stream, indent=2)
    print(f"Results saved to: {json_path}")

    npz_path = artifact_path("data", "npz")
    np.savez(
        npz_path,
        ensemble_variance=ensemble_variance,
        fm_loss=fm_losses,
        **{error_metric_name: error_metric},
        ood_mask=ood_mask,
    )
    print(f"Data saved to: {npz_path}")

    return json_path, npz_path


def print_uq_summary(results: Dict, error_metric_name: str = "surrogate_error"):
    """Print every entry of ``results`` under a SUMMARY banner.

    Nested dicts (e.g. per-metric statistics) are printed as a header line
    followed by indented ``key: value`` lines. Other values are printed on a
    single line.

    Args:
        results: Results dictionary.
        error_metric_name: Currently unused; kept for API compatibility.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    for name, value in results.items():
        if not isinstance(value, dict):
            print(f"{name}: {value}")
            continue
        print(f"\n{name}:")
        for sub_name, sub_value in value.items():
            print(f"  {sub_name}: {sub_value}")


def evaluate_uq_with_ood_generic(
    models: List[torch.nn.Module],
    run_infos: List[Dict],
    checkpoint_names: List[str],
    get_validation_labels: Callable[[torch.device], torch.Tensor],
    get_training_labels: Callable[[torch.device], torch.Tensor],
    compute_error_metric: Callable[
        [torch.Tensor, torch.Tensor, torch.device], torch.Tensor
    ],
    error_metric_name: str,
    num_design_params: int,
    num_labels: int,
    get_conditioning: Optional[
        Callable[[torch.Tensor, torch.Tensor, torch.device], torch.Tensor]
    ] = None,
    nb_samples: int = 2500,
    ood_threshold: float = 0.05,
    ood_grid_steps: int = 25,
    euler_steps: int = 30,
    return_intermediate_data: bool = False,
    ood_difficulty: Optional[str] = None,
    seed: int = 42,
):
    """Generic UQ evaluation with OOD detection.

    This function implements the common UQ evaluation pipeline:
    1. Load validation labels (in-distribution pool) and training labels
       (reference set for OOD generation) through the supplied callbacks.
    2. Generate OOD label candidates, either by calibrated difficulty level
       (``ood_difficulty``) or by distance threshold (``ood_threshold``).
    3. Sample up to ``nb_samples`` points from each pool with a fixed seed.
    4. Pad the labels with a uniform noise complement and generate designs
       with the reference model (``models[0]``) via the inverse pass.
    5. Compute the ensemble variance of the analysis-pass predictions.
    6. Compute the error metric (surrogate error or self-consistency) via
       ``compute_error_metric``.
    7. Compute per-sample FM losses with the reference model.

    Models must be loaded using functions from ensembles.py (e.g.,
    load_unifoil_diag_cfm_ensemble, load_gas_turbine_diag_cfm_ensemble).

    Args:
        models: Pre-loaded ensemble models from ensembles.py. ``models[0]``
            is used as the reference model.
        run_infos: Run info dictionaries for each model. Not used by this
            function; accepted for API compatibility.
        checkpoint_names: Checkpoint names for each model.
        get_validation_labels: Function to get validation labels.
            Signature: (device) -> Tensor of shape (N_val, L)
        get_training_labels: Function to get training labels for OOD generation.
            Signature: (device) -> Tensor of shape (N_train, L)
        compute_error_metric: Function to compute error metric.
            Signature: (designs, labels, device) -> Tensor of shape (N,)
        error_metric_name: Name of the error metric (for results dict).
        num_design_params: Design space dimension P.
        num_labels: Label space dimension L.
        get_conditioning: Optional function to get conditioning for in-dist and OOD.
            Signature: (in_dist_labels, ood_labels, device) -> Tensor of shape (N, C)
            If None, no conditioning is used.
        nb_samples: Number of samples per category.
        ood_threshold: Threshold for OOD point generation (used if ood_difficulty=None).
        ood_grid_steps: Grid steps for OOD point generation.
        euler_steps: Number of Euler integration steps.
        return_intermediate_data: If True, also return intermediate data dict
            containing (reference_model, simulated_designs, all_labels, conditioning).
        ood_difficulty: OOD difficulty level ("easy", "medium", "hard").
            If specified, uses difficulty-based OOD generation with calibrated thresholds.
            If None, uses threshold-based generation with ood_threshold.
        seed: Random seed for reproducible sampling of in-dist and OOD points (default: 42).

    Returns:
        Tuple containing:
        - results: Summary statistics dictionary
        - ensemble_variance: Array of ensemble variance values
        - fm_losses: Array of flow matching loss values
        - error_metric: Array of error metric values
        - ood_mask: Binary mask (1 = OOD, 0 = in-distribution)
        - intermediate_data: (only if return_intermediate_data=True) Dict with
            reference_model, simulated_designs, all_labels, conditioning

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("models must contain at least one pre-loaded model")

    device = get_device()
    print(f"Using device: {device}")

    # Use pre-loaded models (must be loaded via ensembles.py functions)
    print(f"Using {len(models)} pre-loaded ensemble models")
    reference_model = models[0]
    ensemble_num_params = [sum(p.numel() for p in m.parameters()) for m in models]

    # Get validation labels (in-distribution)
    print("Loading validation data...")
    val_labels = get_validation_labels(device)
    print(f"Validation labels shape: {val_labels.shape}")

    # Get training labels for OOD generation
    print("Loading training data for OOD generation...")
    train_labels = get_training_labels(device)
    print(f"Training labels shape (for OOD): {train_labels.shape}")

    # Generate OOD points
    if ood_difficulty is not None:
        print(f"Generating OOD points (difficulty={ood_difficulty})...")
        ood_points, ood_dists, ood_config = get_ood_points_by_difficulty(
            train_labels.cpu(),
            difficulty=ood_difficulty,
            n_points=nb_samples,
            grid_steps=ood_grid_steps,
        )
        print(f"Generated {ood_points.shape[0]} OOD points "
              f"(dist in [{ood_config['min_dist']:.3f}, {ood_config['max_dist']:.3f}])")
    else:
        # `get_ood_points` is not imported at module level, so this default
        # branch used to raise NameError. It is imported locally, so the
        # difficulty-based branch above never depends on it.
        # NOTE(review): assumes it lives in uq_diagcfm.ood_utils next to
        # get_ood_points_by_difficulty; confirm.
        from uq_diagcfm.ood_utils import get_ood_points

        print("Generating OOD points...")
        ood_points, ood_dists = get_ood_points(
            train_labels.cpu(), threshold=ood_threshold, grid_steps=ood_grid_steps
        )
        print(f"Generated {ood_points.shape[0]} OOD points")

    # Sample points with fixed seed for reproducibility. This also reseeds the
    # global RNG, which makes the noise complement below reproducible.
    in_dist_labels, ood_sampled, ood_mask = sample_in_dist_and_ood(
        val_labels=val_labels.cpu(),
        ood_points=ood_points,
        nb_samples=nb_samples,
        device=device,
        seed=seed,
    )

    # Concatenate all points (in-dist first, matching ood_mask)
    all_labels = torch.cat([in_dist_labels, ood_sampled], dim=0)

    # Get conditioning if needed
    conditioning = None
    if get_conditioning is not None:
        conditioning = get_conditioning(in_dist_labels, ood_sampled, device)
        print(f"Conditioning shape: {conditioning.shape}")

    # Augment labels with noise complement: [y; z]
    augmented_labels = torch.cat(
        [
            all_labels,
            create_y_complement(
                all_labels.shape[0], num_design_params, num_labels, device
            ),
        ],
        dim=1,
    )

    # Generate designs from labels (inverse pass)
    print("\nGenerating designs from labels...")
    simulated_designs = generate_designs_from_labels(
        model=reference_model,
        augmented_labels=augmented_labels,
        num_labels=num_labels,
        conditioning=conditioning,
        steps=euler_steps,
    )
    print(f"Simulated designs shape: {simulated_designs.shape}")

    # Compute ensemble variance
    print("Computing ensemble predictions...")
    ensemble_variance = compute_ensemble_variance(
        simulated_designs=simulated_designs,
        models=models,
        num_labels=num_labels,
        device=device,
        conditioning=conditioning,
        steps=euler_steps,
    )

    # Compute error metric
    print(f"Computing {error_metric_name}...")
    error_metric = compute_error_metric(simulated_designs, all_labels, device)

    # Compute FM losses
    print("Computing FM losses...")
    fm_losses = compute_fm_losses(
        simulated_designs=simulated_designs,
        augmented_labels=augmented_labels,
        model=reference_model,
        num_labels=num_labels,
        device=device,
        conditioning=conditioning,
    )

    # Convert to numpy (.cpu() everywhere, so no array depends on where its
    # producer happened to leave the tensor)
    ensemble_variance_np = ensemble_variance.detach().cpu().numpy()
    fm_losses_np = fm_losses.detach().cpu().numpy()
    error_metric_np = error_metric.detach().cpu().numpy()

    # Create results dictionary
    results = create_uq_results_dict(
        ensemble_variance=ensemble_variance_np,
        fm_losses=fm_losses_np,
        error_metric=error_metric_np,
        error_metric_name=error_metric_name,
        ood_mask=ood_mask,
        nb_samples=nb_samples,
        checkpoint_names=checkpoint_names,
        ensemble_num_params=ensemble_num_params,
    )

    if return_intermediate_data:
        intermediate_data = {
            "reference_model": reference_model,
            "simulated_designs": simulated_designs,
            "all_labels": all_labels,
            "conditioning": conditioning,
        }
        return (
            results,
            ensemble_variance_np,
            fm_losses_np,
            error_metric_np,
            ood_mask,
            intermediate_data,
        )

    return (
        results,
        ensemble_variance_np,
        fm_losses_np,
        error_metric_np,
        ood_mask,
    )


# =============================================================================
# Visualization Functions
# =============================================================================


def normalize(data):
    """Min-max scale ``data`` into the [0, 1] range.

    A 1e-10 epsilon in the denominator keeps constant inputs finite; they map
    to all zeros.
    """
    lowest = data.min()
    span = data.max() - lowest
    return (data - lowest) / (span + 1e-10)


def create_violin_plots(data_file, output_file):
    """Create violin plots showing UQ measures for in-dist vs OOD samples.

    Loads ``ensemble_variance``, ``fm_loss`` and ``ood_mask`` from the ``.npz``
    archive ``data_file`` (the keys written by ``save_uq_results``). Each
    score is min-max normalized and split by ``ood_mask`` (0 =
    in-distribution, 1 = OOD). The two-panel figure is saved to
    ``output_file`` at 300 dpi.

    Args:
        data_file: Path to the ``.npz`` archive.
        output_file: Destination path for the figure.
    """
    # Use NpzFile as a context manager so the archive is closed as soon as the
    # arrays have been read, instead of whenever it is garbage-collected.
    with np.load(data_file) as data:
        ensemble_variance = data["ensemble_variance"]
        fm_loss = data["fm_loss"]
        ood_mask = data["ood_mask"]

    in_dist_mask = ood_mask == 0
    ood_detected_mask = ood_mask == 1

    # (normalized scores, y-axis label, panel title), left to right
    panels = [
        (normalize(ensemble_variance), "Normalized Ensemble Variance", "Ensemble Variance"),
        (normalize(fm_loss), "Normalized FM Loss", "FM Loss"),
    ]
    colors = ["green", "red"]

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, (values, ylabel, title) in zip(axes, panels):
        parts = ax.violinplot(
            [values[in_dist_mask], values[ood_detected_mask]],
            positions=[1, 2],
            showmeans=True,
            showmedians=True,
        )
        for body, color in zip(parts["bodies"], colors):
            body.set_facecolor(color)
            body.set_alpha(0.7)

        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_xlabel("Sample Type", fontsize=11)
        ax.set_xticks([1, 2])
        ax.set_xticklabels(["In-Distribution", "OOD"])
        ax.set_title(title, fontsize=12, fontweight="bold")
        ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()

    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Violin plots saved to: {output_file}")

    plt.close(fig)


def create_roc_curves(data_file, output_file):
    """Create ROC curves for OOD detection.

    Loads ``fm_loss``, ``ensemble_variance`` and ``ood_mask`` from the ``.npz``
    archive ``data_file``. Each score is treated as an OOD detector against
    ``ood_mask`` (positive class 1; higher score means more OOD). Both ROC
    curves and the chance diagonal are plotted, and the figure is saved to
    ``output_file`` at 300 dpi.

    Args:
        data_file: Path to the ``.npz`` archive.
        output_file: Destination path for the figure.

    Returns:
        Dict with ``"auc_fm_loss"`` and ``"auc_ensemble"``, the area under
        each ROC curve.
    """
    # Context manager closes the archive once the arrays are read
    with np.load(data_file) as data:
        fm_loss = data["fm_loss"]
        ensemble_variance = data["ensemble_variance"]
        ood_mask = data["ood_mask"]

    # (result key, scores, line style, legend name), plotted in this order
    detectors = [
        ("auc_fm_loss", fm_loss, "g--", "FM Loss"),
        ("auc_ensemble", ensemble_variance, "r:", "Ensemble Variance"),
    ]

    # Compute every curve before opening a figure, as the original did
    aucs = {}
    curves = []
    for key, scores, style, name in detectors:
        fpr, tpr, _ = roc_curve(ood_mask, scores)
        aucs[key] = auc(fpr, tpr)
        curves.append((fpr, tpr, style, f"{name} (AUC = {aucs[key]:.3f})"))

    plt.figure(figsize=(6, 5))
    for fpr, tpr, style, legend in curves:
        plt.plot(fpr, tpr, style, linewidth=2, label=legend)

    # Diagonal reference line
    plt.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.3, label="Random")

    plt.xlabel("False Positive Rate", fontsize=12)
    plt.ylabel("True Positive Rate", fontsize=12)
    plt.title("ROC Curves for OOD Detection", fontsize=13, fontweight="bold")
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"ROC curve saved to: {output_file}")

    plt.close()

    return aucs


def create_point_cloud_plot(
    in_dist_labels: np.ndarray,
    ood_labels: np.ndarray,
    output_file: Path,
    dataset_name: str = "Dataset",
    train_labels: Optional[np.ndarray] = None,
) -> None:
    """Create point cloud visualization of in-distribution vs OOD labels.

    The layout depends on the label dimension ``L = in_dist_labels.shape[1]``:

    * ``L == 3``: two 3D scatter panels (in-dist vs OOD, training vs OOD).
    * ``L == 2``: the same two panels as 2D scatter plots.
    * ``L == 1``: overlaid histograms of the single label.
    * otherwise: three 2D scatter panels for the label pairs (1, 2), (1, 3)
      and (2, 3); ``train_labels`` is not drawn in this layout.

    Points are subsampled for readability. Up to 1000 in-distribution points
    are drawn. OOD points are capped at the same count as the
    in-distribution subsample. Up to 2000 training points are drawn.
    Subsampling uses a private ``np.random.RandomState(42)``, so the figure
    is reproducible without modifying NumPy's global random state.

    Args:
        in_dist_labels: In-distribution labels (N_in, L).
        ood_labels: OOD labels (N_ood, L).
        output_file: Path to save the figure.
        dataset_name: Name of the dataset for the title.
        train_labels: Optional training labels for reference (N_train, L).
    """
    L = in_dist_labels.shape[1]
    # Local RNG: reproducible subsample without reseeding the global NumPy RNG
    # (reseeding would silently change any later np.random draws by callers).
    rng = np.random.RandomState(42)

    # Subsample for clarity
    n_plot = min(1000, len(in_dist_labels))
    idx_in = rng.choice(len(in_dist_labels), n_plot, replace=False)
    idx_ood = rng.choice(len(ood_labels), min(n_plot, len(ood_labels)), replace=False)

    if L == 3:
        fig = plt.figure(figsize=(12, 5))

        # Plot 1: In-distribution vs OOD
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.scatter(in_dist_labels[idx_in, 0], in_dist_labels[idx_in, 1], in_dist_labels[idx_in, 2],
                    c='blue', alpha=0.5, s=10, label='In-distribution')
        ax1.scatter(ood_labels[idx_ood, 0], ood_labels[idx_ood, 1], ood_labels[idx_ood, 2],
                    c='red', alpha=0.5, s=10, label='OOD')
        ax1.set_xlabel('Label 1')
        ax1.set_ylabel('Label 2')
        ax1.set_zlabel('Label 3')
        ax1.set_title(f'{dataset_name}: In-dist vs OOD Labels')
        ax1.legend()

        # Plot 2: With training data if provided
        ax2 = fig.add_subplot(122, projection='3d')
        if train_labels is not None:
            idx_train = rng.choice(len(train_labels), min(2000, len(train_labels)), replace=False)
            ax2.scatter(train_labels[idx_train, 0], train_labels[idx_train, 1], train_labels[idx_train, 2],
                        c='green', alpha=0.3, s=5, label='Training')
        ax2.scatter(ood_labels[idx_ood, 0], ood_labels[idx_ood, 1], ood_labels[idx_ood, 2],
                    c='red', alpha=0.5, s=10, label='OOD')
        ax2.set_xlabel('Label 1')
        ax2.set_ylabel('Label 2')
        ax2.set_zlabel('Label 3')
        ax2.set_title(f'{dataset_name}: Training vs OOD Labels')
        ax2.legend()

    elif L == 2:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Plot 1: In-distribution vs OOD
        axes[0].scatter(in_dist_labels[idx_in, 0], in_dist_labels[idx_in, 1],
                        c='blue', alpha=0.5, s=10, label='In-distribution')
        axes[0].scatter(ood_labels[idx_ood, 0], ood_labels[idx_ood, 1],
                        c='red', alpha=0.5, s=10, label='OOD')
        axes[0].set_xlabel('Label 1')
        axes[0].set_ylabel('Label 2')
        axes[0].set_title(f'{dataset_name}: In-dist vs OOD Labels')
        axes[0].legend()

        # Plot 2: With training data if provided
        if train_labels is not None:
            idx_train = rng.choice(len(train_labels), min(2000, len(train_labels)), replace=False)
            axes[1].scatter(train_labels[idx_train, 0], train_labels[idx_train, 1],
                            c='green', alpha=0.3, s=5, label='Training')
        axes[1].scatter(ood_labels[idx_ood, 0], ood_labels[idx_ood, 1],
                        c='red', alpha=0.5, s=10, label='OOD')
        axes[1].set_xlabel('Label 1')
        axes[1].set_ylabel('Label 2')
        axes[1].set_title(f'{dataset_name}: Training vs OOD Labels')
        axes[1].legend()

    elif L == 1:
        # A single label has no pairwise projection; compare 1D distributions
        # instead of emitting empty projection panels.
        fig, ax = plt.subplots(figsize=(6, 4))
        if train_labels is not None:
            idx_train = rng.choice(len(train_labels), min(2000, len(train_labels)), replace=False)
            ax.hist(train_labels[idx_train, 0], bins=30, color='green', alpha=0.3, label='Training')
        ax.hist(in_dist_labels[idx_in, 0], bins=30, color='blue', alpha=0.5, label='In-distribution')
        ax.hist(ood_labels[idx_ood, 0], bins=30, color='red', alpha=0.5, label='OOD')
        ax.set_xlabel('Label 1')
        ax.set_ylabel('Count')
        ax.set_title(f'{dataset_name}: Label Distributions')
        ax.legend()

    else:
        # For higher dimensions, show 2D projections of the first three labels
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        projections = [(0, 1), (0, 2), (1, 2)]

        for ax, (i, j) in zip(axes, projections):
            ax.scatter(in_dist_labels[idx_in, i], in_dist_labels[idx_in, j],
                       c='blue', alpha=0.5, s=10, label='In-distribution')
            ax.scatter(ood_labels[idx_ood, i], ood_labels[idx_ood, j],
                       c='red', alpha=0.5, s=10, label='OOD')
            ax.set_xlabel(f'Label {i+1}')
            ax.set_ylabel(f'Label {j+1}')
            ax.legend()
            ax.set_title(f'Label {i+1} vs {j+1}')

        plt.suptitle(f'{dataset_name}: 2D Projections of Labels', y=1.02)

    plt.tight_layout()
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"Point cloud plot saved to: {output_file}")
    plt.close(fig)


def create_all_plots(
    metrics: Dict[str, np.ndarray],
    ood_mask: np.ndarray,
    in_dist_labels: np.ndarray,
    ood_labels: np.ndarray,
    output_dir: Path,
    dataset_name: str,
    train_labels: Optional[np.ndarray] = None,
) -> Dict[str, float]:
    """Create all standard plots for UQ evaluation.

    Writes up to three files into ``output_dir``, which is created if
    missing. Each file name is prefixed with ``dataset_name.lower()``:

    * ``*_ood_point_clouds.png``: label point clouds.
    * ``*_ood_violin_plots.pdf``: one violin panel per metric, in-dist vs
      OOD. This file is skipped when ``metrics`` is empty.
    * ``*_ood_roc_curves.pdf``: one ROC curve per metric.

    Metrics named in ``UQ_METRIC_ORDER`` come first, in that order. Any other
    metrics follow in the dict's iteration order.

    Args:
        metrics: Dictionary mapping metric names to arrays of values.
            Expected keys: 'ensemble_variance', 'fm_loss',
            plus any dataset-specific metrics like 'zero_deviation', 'self_consistency'.
        ood_mask: Binary mask (N,) - 1 for OOD, 0 for in-distribution.
        in_dist_labels: In-distribution labels (N_in, L).
        ood_labels: OOD labels (N_ood, L).
        output_dir: Directory to save plots.
        dataset_name: Name of the dataset for titles.
        train_labels: Optional training labels for reference.

    Returns:
        Dictionary of AUC scores for each metric. The value is NaN where the
        AUC could not be computed, for example when ``ood_mask`` contains a
        single class.
    """
    output_dir.mkdir(exist_ok=True, parents=True)
    ood_mask = np.asarray(ood_mask)

    # 1. Point cloud plot
    point_cloud_file = output_dir / f"{dataset_name.lower()}_ood_point_clouds.png"
    create_point_cloud_plot(
        in_dist_labels=in_dist_labels,
        ood_labels=ood_labels,
        output_file=point_cloud_file,
        dataset_name=dataset_name,
        train_labels=train_labels,
    )

    # Canonical order first (skipping absent metrics), then any extra metrics.
    ordered_metrics = [
        (name, metrics[name]) for name in UQ_METRIC_ORDER if name in metrics
    ]
    ordered_metrics.extend(
        (name, values) for name, values in metrics.items() if name not in UQ_METRIC_ORDER
    )

    # 2. Compute AUC scores (NaN on failure, e.g. only one class in ood_mask)
    auc_scores = {}
    for name, values in ordered_metrics:
        try:
            auc_scores[name] = roc_auc_score(ood_mask, values)
        except Exception as e:
            print(f"  Could not compute AUC for {name}: {e}")
            auc_scores[name] = float('nan')

    # 3. Create violin plots for all metrics (in canonical order).
    # plt.subplots(1, 0) raises, so an empty metric set skips this figure.
    if ordered_metrics:
        n_panels = len(ordered_metrics)
        # squeeze=False always yields a 2D array of axes, even for one panel.
        fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

        in_dist_mask = ood_mask == 0
        ood_detected_mask = ood_mask == 1

        for ax, (name, values) in zip(axes[0], ordered_metrics):
            # Normalize for visualization
            norm_values = np.asarray(normalize(values))
            # violinplot fails on an empty dataset, so only draw non-empty groups.
            groups = [
                (1, norm_values[in_dist_mask], "green"),
                (2, norm_values[ood_detected_mask], "red"),
            ]
            groups = [group for group in groups if group[1].size > 0]
            if groups:
                parts = ax.violinplot(
                    [group[1] for group in groups],
                    positions=[group[0] for group in groups],
                    showmeans=True,
                    showmedians=True,
                )
                for pc, (_, _, color) in zip(parts["bodies"], groups):
                    pc.set_facecolor(color)
                    pc.set_alpha(0.7)

            pretty_name = name.replace('_', ' ').title()
            ax.set_ylabel(f"Normalized {pretty_name}", fontsize=10)
            ax.set_xticks([1, 2])
            ax.set_xticklabels(["In-Dist", "OOD"])
            auc_val = auc_scores.get(name, float('nan'))
            ax.set_title(f"{pretty_name}\nAUC={auc_val:.3f}", fontsize=11)
            ax.grid(True, alpha=0.3, axis="y")

        plt.suptitle(f"{dataset_name} UQ Metrics: In-Distribution vs OOD", fontsize=12, fontweight="bold")
        plt.tight_layout()
        violin_file = output_dir / f"{dataset_name.lower()}_ood_violin_plots.pdf"
        plt.savefig(violin_file, dpi=150, bbox_inches='tight')
        print(f"Violin plots saved to: {violin_file}")
        plt.close(fig)
    else:
        print("  No metrics provided; skipping violin plots.")

    # 4. Create ROC curves for all metrics (in canonical order)
    plt.figure(figsize=(8, 6))
    line_styles = ['-', '--', ':', '-.', '-', '--']
    colors_roc = ['blue', 'green', 'red', 'purple', 'orange', 'brown']

    for i, (name, values) in enumerate(ordered_metrics):
        try:
            fpr, tpr, _ = roc_curve(ood_mask, values)
        except Exception as e:
            # Report and skip this metric instead of silently dropping it.
            print(f"  Could not plot ROC curve for {name}: {e}")
            continue
        auc_val = auc(fpr, tpr)
        plt.plot(fpr, tpr, line_styles[i % len(line_styles)],
                 color=colors_roc[i % len(colors_roc)], linewidth=2,
                 label=f"{name.replace('_', ' ').title()} (AUC={auc_val:.3f})")

    plt.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.3, label="Random")
    plt.xlabel("False Positive Rate", fontsize=12)
    plt.ylabel("True Positive Rate", fontsize=12)
    plt.title(f"{dataset_name}: ROC Curves for OOD Detection", fontsize=13, fontweight="bold")
    plt.legend(loc="lower right", fontsize=9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    roc_file = output_dir / f"{dataset_name.lower()}_ood_roc_curves.pdf"
    plt.savefig(roc_file, dpi=150, bbox_inches='tight')
    print(f"ROC curves saved to: {roc_file}")
    plt.close()

    return auc_scores


def print_auc_summary(auc_scores: Dict[str, float]) -> None:
    """Print a banner followed by one line per metric's AUC score.

    Each metric key is prettified (underscores become spaces, title case) and
    left-aligned in a 25-character column; scores are shown to 4 decimals.
    """
    rule = "=" * 50
    report_lines = ["\n" + rule, "OOD DETECTION AUC SCORES", rule]
    report_lines.extend(
        f"  {metric.replace('_', ' ').title():25s}: AUC = {value:.4f}"
        for metric, value in auc_scores.items()
    )
    for line in report_lines:
        print(line)
