"""
Evaluation Functions for Bayesian Neural Networks

This module implements high-level evaluation orchestration for:
- Single model evaluation with STD-based uncertainty
- Multi-model comparison with STD-based uncertainty
- MI-based evaluation (v2)
- Rank sweep experiments
- Ensembles of low-rank Bayesian models (MC prediction and MI-based evaluation)
"""

import gc
import numpy as np
import pandas as pd
import tensorflow as tf
from typing import Dict

from modules.model_builders import build_lowrank_gauss
from modules.inference import (
    mc_predictions,
    mc_predictions_with_mi_v2,
    ensemble_predictions,
)
from modules.metrics import (
    compute_nll,
    compute_auroc,
    compute_aupr_success,
    compute_aupr_error,
    compute_aupr_in_domain,
    compute_aupr_ood,
    find_best_ece,
    compute_aupr_success_mi,
    compute_aupr_error_mi,
    compute_auroc_ood_mi,
    compute_aupr_in_domain_mi,
    compute_aupr_ood_mi,
)
def evaluate_all_metrics_silent(model,
                                X_test, y_test,
                                X_ood, y_ood,
                                model_name: str,
                                n_samples: int = 256,
                                seed: int = 42) -> Dict[str, float]:
    """
    Compute MI- and STD-based metrics for one model without printing anything.

    Produces the same keys as ``evaluate_all_metrics_with_mi_v2`` (AUROC, NLL,
    the five *_MI and five *_STD uncertainty metrics, ECE_best,
    ECE_best_config, Mean_MI_In, Mean_MI_OOD, MI_Ratio), plus the mean-STD
    statistics Mean_STD_In, Mean_STD_OOD and STD_Ratio.

    Parameters:
    -----------
    model : model accepted by mc_predictions_with_mi_v2
    X_test : in-domain features
    y_test : in-domain labels
    X_ood : out-of-domain features
    y_ood : out-of-domain labels (not used; kept for signature parity)
    model_name : not used; kept for signature parity with the printing variant
    n_samples : number of MC samples (default 256)
    seed : seed applied to TensorFlow and NumPy before sampling (default 42)

    Returns:
    --------
    metrics : dict mapping metric name -> value
    """
    # Reset Keras global state and seed both RNGs before MC sampling.
    tf.keras.backend.clear_session()
    gc.collect()
    tf.random.set_seed(seed)
    np.random.seed(seed)

    pred_test, std_test, mi_test = mc_predictions_with_mi_v2(
        model, X_test, n_samples, seed=seed
    )
    # Only the OOD uncertainties are needed; the OOD predictive mean is unused.
    _, std_ood, mi_ood = mc_predictions_with_mi_v2(
        model, X_ood, n_samples, seed=seed
    )

    metrics = {}
    metrics['AUROC'] = compute_auroc(y_test, pred_test)
    # Reported for parity with evaluate_all_metrics_with_mi_v2, so callers do
    # not need a second MC pass just to obtain NLL.
    metrics['NLL'] = compute_nll(y_test, pred_test)

    # MI-based
    metrics['AUPR_Success_MI']    = compute_aupr_success_mi(y_test, pred_test, mi_test)
    metrics['AUPR_Error_MI']      = compute_aupr_error_mi(y_test, pred_test, mi_test)
    metrics['AUROC_OOD_MI']       = compute_auroc_ood_mi(mi_test, mi_ood)
    metrics['AUPR_In_Domain_MI']  = compute_aupr_in_domain_mi(mi_test, mi_ood)
    metrics['AUPR_OOD_MI']        = compute_aupr_ood_mi(mi_test, mi_ood)

    # STD-based (same scoring functions, STD used as the uncertainty signal)
    metrics['AUPR_Success_STD']   = compute_aupr_success_mi(y_test, pred_test, std_test)
    metrics['AUPR_Error_STD']     = compute_aupr_error_mi(y_test, pred_test, std_test)
    metrics['AUROC_OOD_STD']      = compute_auroc_ood_mi(std_test, std_ood)
    metrics['AUPR_In_Domain_STD'] = compute_aupr_in_domain_mi(std_test, std_ood)
    metrics['AUPR_OOD_STD']       = compute_aupr_ood_mi(std_test, std_ood)

    ece_results = find_best_ece(y_test, pred_test)
    metrics['ECE_best']        = ece_results['best_ece']
    metrics['ECE_best_config'] = ece_results['best_config']

    # The 1e-10 keeps the ratios finite when in-domain uncertainty is zero.
    metrics['Mean_MI_In']    = mi_test.mean()
    metrics['Mean_MI_OOD']   = mi_ood.mean()
    metrics['MI_Ratio']      = mi_ood.mean() / (mi_test.mean() + 1e-10)
    metrics['Mean_STD_In']   = std_test.mean()
    metrics['Mean_STD_OOD']  = std_ood.mean()
    metrics['STD_Ratio']     = std_ood.mean() / (std_test.mean() + 1e-10)

    return metrics


def evaluate_all_metrics(model: tf.keras.Model,
                         X_test: np.ndarray,
                         y_test: np.ndarray,
                         X_ood: np.ndarray,
                         y_ood: np.ndarray,
                         model_name: str,
                         n_samples: int = 512,
                         seed: int = 42) -> Dict[str, float]:
    """
    Compute all metrics for a single model using STD-based uncertainty.

    The per-sample uncertainty is the second output of ``mc_predictions``.
    Progress and the headline metrics are printed as they are computed.

    Parameters:
    -----------
    model : trained deterministic or Bayesian neural network
    X_test : test set features (in-domain)
    y_test : test set labels (in-domain)
    X_ood : out-of-domain features (e.g., newborns)
    y_ood : out-of-domain labels (not used by the computation; kept for
            signature symmetry)
    model_name : name of the model (used only in printed output)
    n_samples : number of MC samples (default 512)
    seed : random seed (default 42)

    Returns:
    --------
    metrics : dict with keys 'NLL', 'AUROC', 'AUPR_Success', 'AUPR_Error',
              'AUPR_In_Domain', 'AUPR_OOD', 'ECE_best', 'ECE_best_config',
              one 'ECE_<config>' entry per additional key returned by
              find_best_ece, 'Mean_Uncertainty_In', 'Mean_Uncertainty_OOD'
              and 'Uncertainty_Ratio'.
    """
    print(f"\n{'='*80}")
    print(f"Evaluating: {model_name}")
    print(f"{'='*80}")
    tf.keras.backend.clear_session()
    gc.collect()
    # Reset seeds
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # ---- Get predictions and uncertainties for in-domain (test) data ----
    print("Computing in-domain predictions...")
    pred_test, unc_test = mc_predictions(model, X_test, n_samples)

    # ---- Get predictions and uncertainties for OOD data ----
    print("Computing OOD predictions...")
    pred_ood, unc_ood = mc_predictions(model, X_ood, n_samples)

    # ---- Compute all metrics ----
    print("Computing metrics...")

    metrics = {}

    # 1. NLL (Negative Log Likelihood)
    metrics['NLL'] = compute_nll(y_test, pred_test)
    print(f"  NLL: {metrics['NLL']:.4f}")

    # 2. AUROC (classification performance)
    metrics['AUROC'] = compute_auroc(y_test, pred_test)
    print(f"  AUROC: {metrics['AUROC']:.4f}")

    # 3. AUPR-Success (uncertainty identifies correct predictions)
    metrics['AUPR_Success'] = compute_aupr_success(y_test, pred_test, unc_test)
    print(f"  AUPR-Success: {metrics['AUPR_Success']:.4f}")

    # 4. AUPR-Error (uncertainty identifies incorrect predictions)
    metrics['AUPR_Error'] = compute_aupr_error(y_test, pred_test, unc_test)
    print(f"  AUPR-Error: {metrics['AUPR_Error']:.4f}")

    # 5. AUPR-In-Domain (uncertainty identifies in-domain samples)
    metrics['AUPR_In_Domain'] = compute_aupr_in_domain(unc_test, unc_ood)
    print(f"  AUPR-In-Domain: {metrics['AUPR_In_Domain']:.4f}")

    # 6. AUPR-Out-of-Domain (uncertainty identifies OOD samples)
    metrics['AUPR_OOD'] = compute_aupr_ood(unc_test, unc_ood)
    print(f"  AUPR-OOD: {metrics['AUPR_OOD']:.4f}")

    # 7. ECE (calibration error) - find best configuration
    ece_results = find_best_ece(y_test, pred_test)
    metrics['ECE_best'] = ece_results['best_ece']
    metrics['ECE_best_config'] = ece_results['best_config']
    # Also store individual ECE configurations
    for config_name, ece_val in ece_results.items():
        if config_name not in ('best_config', 'best_ece'):
            metrics[f'ECE_{config_name}'] = ece_val
    print(f"  ECE (best): {metrics['ECE_best']:.4f} [{metrics['ECE_best_config']}]")

    # Additional statistics
    mean_unc_in = unc_test.mean()
    mean_unc_ood = unc_ood.mean()
    metrics['Mean_Uncertainty_In'] = mean_unc_in
    metrics['Mean_Uncertainty_OOD'] = mean_unc_ood
    # The 1e-10 guards against a zero in-domain uncertainty (e.g. a
    # deterministic model whose MC samples presumably all coincide), which
    # previously produced nan/inf. It matches the MI/STD ratios elsewhere.
    metrics['Uncertainty_Ratio'] = mean_unc_ood / (mean_unc_in + 1e-10)

    print(f"  Mean Uncertainty (In-Domain): {metrics['Mean_Uncertainty_In']:.4f}")
    print(f"  Mean Uncertainty (OOD): {metrics['Mean_Uncertainty_OOD']:.4f}")
    print(f"  Uncertainty Ratio (OOD/In): {metrics['Uncertainty_Ratio']:.2f}x")

    return metrics


def evaluate_all_models(models_dict: Dict,
                        X_test,
                        y_test,
                        X_ood,
                        y_ood,
                        n_samples: int = 512,
                        seed: int = 42) -> pd.DataFrame:
    """
    Evaluate all models (deterministic, Bayesian, and ensemble) and return results as a DataFrame.

    Single models are delegated to ``evaluate_all_metrics`` (MC sampling with
    ``n_samples``). A list of models is treated as a deep ensemble and scored
    from ``ensemble_predictions``; ``n_samples`` does not apply to it.

    Parameters:
    -----------
    models_dict : dict of {model_name: model or list_of_models}
                  For deterministic/Bayesian models, value is a single model
                  For ensemble, value is a list of models
    X_test : test set features (DataFrame or numpy array)
    y_test : test set labels (Series or numpy array)
    X_ood : OOD features (DataFrame or numpy array)
    y_ood : OOD labels (Series or numpy array)
    n_samples : number of MC samples for Bayesian models
    seed : random seed, re-applied before each model is evaluated

    Returns:
    --------
    results_df : pandas DataFrame with one row per model. The columns are the
                 core metrics (NLL ... Uncertainty_Ratio) followed by any
                 extra 'ECE_<config>' columns. An empty ``models_dict`` yields
                 an empty DataFrame with just the core columns.
    """
    core_columns = [
        'NLL',
        'AUROC',
        'AUPR_Success',
        'AUPR_Error',
        'AUPR_In_Domain',
        'AUPR_OOD',
        'ECE_best',
        'ECE_best_config',
        'Mean_Uncertainty_In',
        'Mean_Uncertainty_OOD',
        'Uncertainty_Ratio'
    ]

    # Pandas -> numpy conversion is loop-invariant, so do it once.
    X_test_np = X_test.values if hasattr(X_test, 'values') else X_test
    y_test_np = y_test.values if hasattr(y_test, 'values') else y_test
    X_ood_np = X_ood.values if hasattr(X_ood, 'values') else X_ood
    y_ood_np = y_ood.values if hasattr(y_ood, 'values') else y_ood

    results = {}

    for model_name, model in models_dict.items():
        tf.keras.backend.clear_session()
        gc.collect()
        # Reset seeds so every model starts from the same RNG state
        tf.random.set_seed(seed)
        np.random.seed(seed)

        if not isinstance(model, list):
            # Single deterministic/Bayesian model
            results[model_name] = evaluate_all_metrics(
                model, X_test_np, y_test_np, X_ood_np, y_ood_np,
                model_name, n_samples, seed
            )
            continue

        # Deep ensemble (list of models)
        print(f"\n{'='*80}")
        print(f"Evaluating: {model_name}")
        print(f"{'='*80}")

        # Ensembles receive the features exactly as the caller passed them.
        pred_test, unc_test = ensemble_predictions(model, X_test)
        # Only the OOD uncertainty is needed.
        _, unc_ood = ensemble_predictions(model, X_ood)

        metrics = {}
        metrics['NLL'] = compute_nll(y_test_np, pred_test)
        metrics['AUROC'] = compute_auroc(y_test_np, pred_test)
        metrics['AUPR_Success'] = compute_aupr_success(y_test_np, pred_test, unc_test)
        metrics['AUPR_Error'] = compute_aupr_error(y_test_np, pred_test, unc_test)
        metrics['AUPR_In_Domain'] = compute_aupr_in_domain(unc_test, unc_ood)
        metrics['AUPR_OOD'] = compute_aupr_ood(unc_test, unc_ood)

        # ECE - find best configuration, and keep every individual config
        ece_results = find_best_ece(y_test_np, pred_test)
        metrics['ECE_best'] = ece_results['best_ece']
        metrics['ECE_best_config'] = ece_results['best_config']
        for config_name, ece_val in ece_results.items():
            if config_name not in ('best_config', 'best_ece'):
                metrics[f'ECE_{config_name}'] = ece_val

        mean_unc_in = unc_test.mean()
        mean_unc_ood = unc_ood.mean()
        metrics['Mean_Uncertainty_In'] = mean_unc_in
        metrics['Mean_Uncertainty_OOD'] = mean_unc_ood
        # The 1e-10 keeps the ratio finite when in-domain uncertainty is zero.
        metrics['Uncertainty_Ratio'] = mean_unc_ood / (mean_unc_in + 1e-10)

        print(f"  NLL: {metrics['NLL']:.4f}")
        print(f"  AUROC: {metrics['AUROC']:.4f}")
        print(f"  ECE (best): {metrics['ECE_best']:.4f} [{metrics['ECE_best_config']}]")

        results[model_name] = metrics

    if not results:
        # pd.DataFrame({}).T has no columns, so selecting core_columns from it
        # would raise KeyError; return an empty, correctly-shaped frame instead.
        return pd.DataFrame(columns=core_columns)

    # One row per model
    df = pd.DataFrame(results).T

    # Core columns first, then any additional per-config ECE columns
    ece_columns = [col for col in df.columns if col.startswith('ECE_') and col not in core_columns]
    return df[core_columns + ece_columns]


def evaluate_all_metrics_with_mi_v2(model,
                                     X_test: np.ndarray,
                                     y_test: np.ndarray,
                                     X_ood: np.ndarray,
                                     y_ood: np.ndarray,
                                     model_name: str,
                                     n_samples: int = 512,
                                     seed: int = 42) -> Dict[str, float]:
    """
    Score one model using mutual information (MI) as the epistemic signal.

    Uncertainties come from ``mc_predictions_with_mi_v2``, which yields a
    predictive mean, an STD and an MI value per sample. Every uncertainty
    metric is computed twice: once from MI and once from STD. Only the MI
    variants are printed; the STD variants are returned for comparison.

    Parameters:
    -----------
    model : trained Bayesian neural network OR list of models (Deep Ensemble),
            i.e. anything mc_predictions_with_mi_v2 accepts
    X_test : in-domain features
    y_test : in-domain labels
    X_ood : out-of-domain features
    y_ood : out-of-domain labels (not used by the computation)
    model_name : label used only in the printed banner
    n_samples : number of MC samples (default 512)
    seed : seed for TensorFlow/NumPy and for the sampler (default 42)

    Returns:
    --------
    metrics : dict containing AUROC, NLL, the *_MI and *_STD uncertainty
              metrics, ECE_best, ECE_best_config, Mean_MI_In, Mean_MI_OOD
              and MI_Ratio
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"Evaluating with MI: {model_name}")
    print(rule)

    print("Computing in-domain predictions with MI...")
    # Fresh Keras state and seeded RNGs before sampling.
    tf.keras.backend.clear_session()
    gc.collect()
    tf.random.set_seed(seed)
    np.random.seed(seed)
    probs_in, spread_in, info_in = mc_predictions_with_mi_v2(
        model, X_test, n_samples, seed=seed
    )

    print("Computing OOD predictions with MI...")
    # The OOD predictive mean is not needed for any metric.
    _, spread_out, info_out = mc_predictions_with_mi_v2(
        model, X_ood, n_samples, seed=seed
    )

    print("Computing metrics...")
    scores = {}

    # Classification quality and likelihood share the in-domain mean.
    scores['AUROC'] = compute_auroc(y_test, probs_in)
    print(f"  AUROC: {scores['AUROC']:.4f}")
    scores['NLL'] = compute_nll(y_test, probs_in)
    print(f"  NLL: {scores['NLL']:.4f}")

    # MI-driven uncertainty metrics: (result key, printed label, fn, args).
    print("\n  MI-based metrics:")
    mi_table = (
        ('AUPR_Success_MI', 'AUPR-Success (MI)', compute_aupr_success_mi, (y_test, probs_in, info_in)),
        ('AUPR_Error_MI', 'AUPR-Error (MI)', compute_aupr_error_mi, (y_test, probs_in, info_in)),
        ('AUROC_OOD_MI', 'AUROC-OOD (MI)', compute_auroc_ood_mi, (info_in, info_out)),
        ('AUPR_In_Domain_MI', 'AUPR-In-Domain (MI)', compute_aupr_in_domain_mi, (info_in, info_out)),
        ('AUPR_OOD_MI', 'AUPR-OOD (MI)', compute_aupr_ood_mi, (info_in, info_out)),
    )
    for key, label, scorer, args in mi_table:
        scores[key] = scorer(*args)
        print(f"    {label}: {scores[key]:.4f}")

    # STD-driven counterparts: same scoring functions, STD as the signal.
    # These are stored for comparison but intentionally not printed.
    std_table = (
        ('AUPR_Success_STD', compute_aupr_success_mi, (y_test, probs_in, spread_in)),
        ('AUPR_Error_STD', compute_aupr_error_mi, (y_test, probs_in, spread_in)),
        ('AUROC_OOD_STD', compute_auroc_ood_mi, (spread_in, spread_out)),
        ('AUPR_In_Domain_STD', compute_aupr_in_domain_mi, (spread_in, spread_out)),
        ('AUPR_OOD_STD', compute_aupr_ood_mi, (spread_in, spread_out)),
    )
    for key, scorer, args in std_table:
        scores[key] = scorer(*args)

    # Calibration
    calibration = find_best_ece(y_test, probs_in)
    scores['ECE_best'] = calibration['best_ece']
    scores['ECE_best_config'] = calibration['best_config']
    print(f"\n  ECE (best): {scores['ECE_best']:.4f} [{scores['ECE_best_config']}]")

    # Summary MI statistics; 1e-10 keeps the ratio finite for zero in-domain MI.
    avg_info_in = mi_mean_in = info_in.mean()
    avg_info_out = info_out.mean()
    scores['Mean_MI_In'] = avg_info_in
    scores['Mean_MI_OOD'] = avg_info_out
    scores['MI_Ratio'] = avg_info_out / (mi_mean_in + 1e-10)

    print("\n  Uncertainty Statistics:")
    print(f"    Mean MI (In): {scores['Mean_MI_In']:.4f}")
    print(f"    Mean MI (OOD): {scores['Mean_MI_OOD']:.4f}")
    print(f"    MI Ratio (OOD/In): {scores['MI_Ratio']:.2f}x")

    return scores


def evaluate_all_models_with_mi_v2(models_dict: Dict,
                                     X_test,
                                     y_test,
                                     X_ood,
                                     y_ood,
                                     n_samples: int = 512,
                                     seed: int = 42) -> pd.DataFrame:
    """
    Run evaluate_all_metrics_with_mi_v2 on every entry of ``models_dict``.

    Each value may be a single (deterministic or Bayesian) model or a list of
    models (Deep Ensemble). Both are passed unchanged to
    evaluate_all_metrics_with_mi_v2. Pandas inputs are converted to numpy
    arrays via ``.values``; other inputs are passed through.

    Parameters:
    -----------
    models_dict : dict of {model_name: model or list_of_models}
    X_test : test set features
    y_test : test set labels
    X_ood : OOD features
    y_ood : OOD labels
    n_samples : number of MC samples
    seed : random seed, re-applied before each model

    Returns:
    --------
    df : pandas DataFrame with one row per model. Only the MI-based metrics,
         ECE, NLL and MI statistics are kept (the STD-based metrics are
         dropped), in a fixed order restricted to the columns present.
    """
    def to_array(data):
        # Unwrap pandas containers; leave anything else untouched.
        return data.values if hasattr(data, 'values') else data

    # These conversions do not depend on the model, so do them once.
    feats_in, labels_in = to_array(X_test), to_array(y_test)
    feats_out, labels_out = to_array(X_ood), to_array(y_ood)

    per_model = {}
    for name, candidate in models_dict.items():
        tf.keras.backend.clear_session()
        gc.collect()
        tf.random.set_seed(seed)
        np.random.seed(seed)

        per_model[name] = evaluate_all_metrics_with_mi_v2(
            candidate, feats_in, labels_in, feats_out, labels_out,
            name, n_samples, seed
        )

    table = pd.DataFrame(per_model).T

    preferred_order = (
        'AUROC',
        'AUPR_Success_MI',
        'AUPR_Error_MI',
        'AUROC_OOD_MI',
        'AUPR_In_Domain_MI',
        'AUPR_OOD_MI',
        'ECE_best',
        'ECE_best_config',
        'NLL',
        'Mean_MI_In',
        'Mean_MI_OOD',
        'MI_Ratio',
    )
    present = [column for column in preferred_order if column in table.columns]
    return table[present]


def rank_sweep_lowrank_gauss(
    ranks1,
    ranks2,
    X_train, y_train,
    X_test, y_test,
    X_ood, y_ood,
    feature_dim: int,
    class_weight: dict,
    det_model: tf.keras.Model | None = None,
    epochs: int = 40,
    batch_size: int = 128,
    subset_frac: float = 0.2,
    seed: int = 42,
    n_samples: int = 256,
):
    """
    Perform a grid sweep over (rank1, rank2) for build_lowrank_gauss.

    A fixed random subset of the training data is drawn once
    (``subset_frac`` of the rows, using ``seed``). For every pair
    ``(r1, r2)`` in ``ranks1 x ranks2``, a fresh model is built, trained on
    that subset for ``epochs`` epochs, and evaluated on the test and OOD sets
    via evaluate_all_metrics_silent.

    Parameters
    ----------
    ranks1, ranks2 : iterables of ranks passed as rank1 / rank2
    X_train, y_train : training data (pandas objects or array-likes)
    X_test, y_test : in-domain test data (pandas objects or array-likes)
    X_ood, y_ood : out-of-domain data (pandas objects or array-likes)
    feature_dim : input dimension passed to build_lowrank_gauss
    class_weight : forwarded to ``model.fit``
    det_model : currently ignored; ``init_from_deterministic=None`` is
        always passed to build_lowrank_gauss
    epochs, batch_size : training settings for ``model.fit``
    subset_frac : fraction of training rows used; clamped so that at least
        one and at most all rows are used
    seed : seed for the subset draw, TF/NumPy RNGs and MC sampling
    n_samples : MC samples used for evaluation (default 256)

    Returns
    -------
    list of dict, one per (r1, r2) pair: the metrics from
    evaluate_all_metrics_silent, followed by 'NLL', 'rank1' and 'rank2'.

    Raises
    ------
    ValueError : if the training set is empty
    """
    def _as_array(data):
        # Unwrap pandas containers; make other array-likes indexable by idx.
        return data.values if hasattr(data, "values") else np.asarray(data)

    X_train_np, y_train_np = _as_array(X_train), _as_array(y_train)
    X_test_np, y_test_np = _as_array(X_test), _as_array(y_test)
    X_ood_np, y_ood_np = _as_array(X_ood), _as_array(y_ood)

    # ----- build a fixed subset for speed -----
    rng = np.random.RandomState(seed)
    n_train = len(X_train_np)
    if n_train == 0:
        raise ValueError("X_train must contain at least one sample")
    # Clamp so subset_frac > 1 cannot request more rows than exist
    # (rng.choice with replace=False would raise).
    subset_size = min(n_train, max(1, int(subset_frac * n_train)))
    idx = rng.choice(n_train, size=subset_size, replace=False)
    x_sub = X_train_np[idx]
    y_sub = y_train_np[idx]

    results = []  # each entry: dict with rank1, rank2, metrics

    for r1 in ranks1:
        for r2 in ranks2:
            name = f"LowRank_r1={r1}_r2={r2}"
            print(f"\n=== Training {name} on subset (size={subset_size}) ===")

            # Reset seeds for reproducibility
            tf.keras.backend.clear_session()
            gc.collect()
            tf.random.set_seed(seed)
            np.random.seed(seed)

            # NOTE(review): deterministic initialisation is disabled; det_model
            # is accepted but not forwarded. Pass det_model here to re-enable.
            model = build_lowrank_gauss(
                input_dim=feature_dim,
                rank1=r1,
                rank2=r2,
                init_from_deterministic=None,
            )

            # Train on the subset. The test set is only reported as
            # validation data; no callbacks are passed, so it is not used
            # for early stopping or checkpoint selection.
            model.fit(
                x_sub, y_sub,
                validation_data=(X_test_np, y_test_np),
                epochs=epochs,
                batch_size=batch_size,
                class_weight=class_weight,
                shuffle=True,
                verbose=0,
            )

            # Evaluate on test + OOD
            metrics = evaluate_all_metrics_silent(
                model,
                X_test_np, y_test_np,
                X_ood_np,  y_ood_np,
                model_name=name,
                n_samples=n_samples,
                seed=seed,
            )

            # Reuse the evaluator's NLL when it provides one. Otherwise pay
            # for one extra MC pass on the test set. Popping and re-inserting
            # keeps 'NLL' after the evaluator's other keys.
            nll = metrics.pop("NLL", None)
            if nll is None:
                pred_test, _, _ = mc_predictions_with_mi_v2(
                    model, X_test_np, n_samples=n_samples, seed=seed
                )
                nll = compute_nll(y_test_np, pred_test)

            metrics["NLL"] = nll
            metrics["rank1"] = r1
            metrics["rank2"] = r2

            results.append(metrics)

    return results


def mc_predictions_lr_bayes_ensemble(
    ensemble,           # list of Bayesian low-rank models
    X: np.ndarray,
    n_samples: int = 256,
    seed: int = 42,
):
    """
    MC predictions for an ensemble of low-rank Bayesian models.

    For each member ``m`` (at position ``idx``) in the ensemble:
      - call mc_predictions_with_mi_v2(m, X, n_samples, seed + idx), which
        returns (mean_m, std_m, mi_m).
    Then:
      - average the member means to get the ensemble predictive mean;
      - combine member variances via the law of total variance:
        Var = mean_m(std_m**2) + Var_m(mean_m);
      - average the member MI values per point.

    Note: ``mi_ens`` is the plain mean of the member MIs. Disagreement
    between members' means only enters ``pred_std``, not ``mi_ens``.

    Parameters:
      ensemble  : iterable of models (consumed once into a list)
      X         : inputs passed to each member
      n_samples : MC samples per member
      seed      : global seed; member ``idx`` is sampled with ``seed + idx``

    Returns:
      pred_mean : ensemble predictive mean
      pred_std  : ensemble predictive std (law of total variance)
      mi_ens    : mean of the member MI values
      Each output presumably has shape (N,), matching the per-member outputs
      of mc_predictions_with_mi_v2 (TODO confirm).

    Raises:
      ValueError : if ``ensemble`` is empty
    """
    # Materialize so generators work and emptiness can be checked up front.
    members = list(ensemble)
    if not members:
        raise ValueError("ensemble must contain at least one model")

    tf.random.set_seed(seed)
    np.random.seed(seed)

    member_means = []
    member_vars = []
    member_mis = []

    for idx, member in enumerate(members):
        mean_m, std_m, mi_m = mc_predictions_with_mi_v2(
            member, X, n_samples=n_samples, seed=seed + idx
        )
        member_means.append(np.asarray(mean_m))
        member_vars.append(np.asarray(std_m) ** 2)
        member_mis.append(np.asarray(mi_m))

    # Stack along a new leading member axis: (M, ...)
    member_means = np.stack(member_means, axis=0)
    member_vars = np.stack(member_vars, axis=0)
    member_mis = np.stack(member_mis, axis=0)

    # Ensemble predictive mean: average over members
    pred_mean = member_means.mean(axis=0)

    # Law of total variance across members:
    # Var(Y) = E_m[Var(Y|m)] + Var_m(E[Y|m])
    aleatoric = member_vars.mean(axis=0)
    epistemic = member_means.var(axis=0)
    pred_std = np.sqrt(aleatoric + epistemic)

    # Aggregate MI from members: mean MI per point
    mi_ens = member_mis.mean(axis=0)

    return pred_mean, pred_std, mi_ens

def evaluate_lr_bayes_ensemble(
    ensemble,                  # list of Bayesian low-rank models
    X_test: np.ndarray,
    y_test: np.ndarray,
    X_ood: np.ndarray,
    y_ood: np.ndarray,
    model_name: str,
    n_samples: int = 512,
    seed: int = 42,
) -> Dict[str, float]:
    """
    Evaluate an ensemble of low-rank Bayesian models with MC per member.

    Predictions and uncertainties come from mc_predictions_lr_bayes_ensemble,
    whose MI is the per-point mean of the members' mc_predictions_with_mi_v2
    MI values. Only NLL is printed.

    Parameters:
    -----------
    ensemble : list of Bayesian low-rank models
    X_test : in-domain features
    y_test : in-domain labels (numpy array or pandas Series)
    X_ood : out-of-domain features
    y_ood : out-of-domain labels (not used)
    model_name : not used; kept for signature parity
    n_samples : MC samples per member (default 512)
    seed : base seed (default 42)

    Returns:
    --------
    metrics : dict with AUROC, the five *_MI metrics, ECE_best,
              ECE_best_config, Mean_MI_In, Mean_MI_OOD, MI_Ratio and NLL
    """
    # Normalise labels once. Previously only the NLL call used `.values`,
    # which raised AttributeError for the annotated np.ndarray input.
    y_test_np = y_test.values if hasattr(y_test, 'values') else y_test

    # In-domain predictions (STD is returned but not used for any metric)
    pred_test, _, mi_test = mc_predictions_lr_bayes_ensemble(
        ensemble, X_test, n_samples=n_samples, seed=seed
    )

    # OOD predictions: only the MI is needed
    _, _, mi_ood = mc_predictions_lr_bayes_ensemble(
        ensemble, X_ood, n_samples=n_samples, seed=seed
    )

    metrics = {}

    # Classification performance
    metrics['AUROC'] = compute_auroc(y_test_np, pred_test)

    # MI-based metrics
    metrics['AUPR_Success_MI']   = compute_aupr_success_mi(y_test_np, pred_test, mi_test)
    metrics['AUPR_Error_MI']     = compute_aupr_error_mi(y_test_np, pred_test, mi_test)
    metrics['AUROC_OOD_MI']      = compute_auroc_ood_mi(mi_test, mi_ood)
    metrics['AUPR_In_Domain_MI'] = compute_aupr_in_domain_mi(mi_test, mi_ood)
    metrics['AUPR_OOD_MI']       = compute_aupr_ood_mi(mi_test, mi_ood)

    # ECE
    ece_results = find_best_ece(y_test_np, pred_test)
    metrics['ECE_best']        = ece_results['best_ece']
    metrics['ECE_best_config'] = ece_results['best_config']

    # Uncertainty statistics; 1e-10 keeps the ratio finite for zero in-domain MI
    mean_mi_in = mi_test.mean()
    mean_mi_ood = mi_ood.mean()
    metrics['Mean_MI_In']   = mean_mi_in
    metrics['Mean_MI_OOD']  = mean_mi_ood
    metrics['MI_Ratio']     = mean_mi_ood / (mean_mi_in + 1e-10)

    metrics["NLL"] = compute_nll(y_test_np, pred_test)
    print(f"    NLL: {metrics['NLL']:.4f}")
    return metrics
