"""
evaluation.py - Evaluation utilities for Bayesian Transformers.

This module provides:
- Monte Carlo sampling for Bayesian model predictions
- Ensemble prediction aggregation for Deep Ensembles
- Comprehensive metrics: AUROC, AUPR, ECE, NLL, Brier Score
- OOD detection metrics based on MI and STD
- Unified evaluation pipeline that respects each model's specifics

Key design principle:
- Deep ensemble models are evaluated using ensemble prediction logic (aggregate over members)
- Variational/Bayesian models are evaluated using MC sampling logic (multiple stochastic forward passes)
"""

import gc
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score

from modules.config import set_seed, N_MC_SAMPLES, N_ECE_BINS, IMPROVED_N_MC_SAMPLES
from modules.bayesian_layers import set_dropout_active
from modules.model_builders import DeepEnsemble


# =============================================================================
# MONTE CARLO SAMPLING FOR BAYESIAN MODELS
# =============================================================================

def mc_predictions_with_mi(model, X, n_samples=50, epsilon=1e-10):
    """
    Monte Carlo sampling for Bayesian model predictions.

    Runs ``n_samples`` stochastic forward passes (``training=True``) through
    ``model``, converts each pass's logits to P(class=1) with a softmax, and
    summarises the samples per input (mean, variance, MI approximation, std).

    Parameters
    ----------
    model : tf.keras.Model
        Bayesian model with variational layers. Expected to return logits of
        shape (N, 2) for N inputs; column 1 is treated as the positive class.
    X : dict or array-like
        Model inputs. If a dict, ``len(X['input_ids'])`` gives N; otherwise
        ``len(X)`` does.
    n_samples : int
        Number of MC samples (default: 50)
    epsilon : float
        Added to the MI denominator to avoid division by zero.

    Returns
    -------
    predictions : np.ndarray, shape (N,)
        Mean P(class=1) across MC samples for each input.
    variance : np.ndarray, shape (N,)
        Variance of P(class=1) across MC samples.
    mi : np.ndarray, shape (N,)
        Mutual-information approximation ``variance / (2 * mu * (1 - mu))``,
        clipped to [0, 100].
    std : np.ndarray, shape (N,)
        Standard deviation of P(class=1) across MC samples.
    """
    set_seed()
    n_inputs = len(X['input_ids']) if isinstance(X, dict) else len(X)

    print(f"  Sampling {n_samples} times for {n_inputs} inputs...", end=" ")

    # One row per stochastic pass, one column per input, holding P(class=1).
    # Keeping this 2-D (rather than (n_samples, N, 1) followed by squeeze())
    # guarantees every returned array is 1-D of length N, even when N == 1;
    # a squeezed 0-d result breaks np.concatenate / len() downstream.
    sampled_probs = np.zeros((n_samples, n_inputs))

    for i in range(n_samples):
        # training=True keeps the variational layers stochastic, so each pass
        # draws a fresh sample. NOTE(review): this also puts any other
        # training-dependent layers into training mode; callers in this module
        # call set_dropout_active(model, active=False) beforehand.
        logits = model(X, training=True)
        probs = tf.nn.softmax(logits, axis=-1).numpy()
        sampled_probs[i] = probs[:, 1]

    print("Done.")

    predictions = sampled_probs.mean(axis=0)
    variance = sampled_probs.var(axis=0)
    std = sampled_probs.std(axis=0)

    # Second-order Taylor expansion of the BALD mutual information for a
    # Bernoulli output: MI = H(mu) - E[H(p)] ~= Var[p] / (2 * mu * (1 - mu)).
    # The clip bounds the blow-up when mu is very close to 0 or 1.
    mi = variance / (2 * predictions * (1 - predictions) + epsilon)
    mi = np.clip(mi, 0, 100)

    return predictions, variance, mi, std


# =============================================================================
# ENSEMBLE PREDICTIONS
# =============================================================================

def ensemble_predictions_with_uncertainty(ensemble, X, epsilon=1e-10):
    """
    Predict with a deep ensemble and derive uncertainty from member disagreement.

    No stochastic sampling is involved: the spread between ensemble members,
    as reported by ``ensemble.predict``, plays the role that MC samples play
    for variational models.

    Parameters
    ----------
    ensemble : DeepEnsemble
        Trained deep ensemble model.
    X : dict
        Input data dictionary with 'input_ids' and 'attention_mask'.
    epsilon : float
        Added to the MI denominator to avoid division by zero.

    Returns
    -------
    predictions : np.ndarray
        Ensemble mean prediction (probability of class 1), as returned by
        ``ensemble.predict``. NOTE(review): assumed to be 1-D of length N,
        matching mc_predictions_with_mi; confirm against DeepEnsemble.predict.
    variance : np.ndarray
        Square of the ensemble standard deviation.
    mi : np.ndarray
        Approximation ``variance / (2 * mu * (1 - mu))``, clipped to [0, 100].
    std : np.ndarray
        Ensemble standard deviation (epistemic uncertainty), as returned by
        ``ensemble.predict``.
    """
    set_seed()
    # Individual member outputs are requested but not needed here.
    mean_pred, std_pred, _member_preds = ensemble.predict(X, return_individual=True)

    variance = std_pred ** 2
    # Same Bernoulli MI approximation as the MC path, so both model families
    # produce comparable uncertainty scores.
    mi = np.clip(variance / (2 * mean_pred * (1 - mean_pred) + epsilon), 0, 100)

    return mean_pred, variance, mi, std_pred


def mc_predictions_ensemble_compatible(model, X, n_samples=50, epsilon=1e-10):
    """
    Dispatch to the prediction routine appropriate for ``model``.

    DeepEnsemble instances go through ensemble_predictions_with_uncertainty
    (member disagreement). Every other model goes through
    mc_predictions_with_mi (stochastic forward passes).

    Parameters
    ----------
    model : keras.Model or DeepEnsemble
        The model to generate predictions from.
    X : dict
        Input data dictionary.
    n_samples : int
        Number of MC samples; ignored for ensembles.
    epsilon : float
        Small value for numerical stability.

    Returns
    -------
    predictions, variance, mi, std : np.ndarray
        Same format as mc_predictions_with_mi.
    """
    set_seed()
    if not isinstance(model, DeepEnsemble):
        # Bayesian / variational model: MC sampling.
        return mc_predictions_with_mi(model, X, n_samples, epsilon)
    # Deep ensemble: n_samples is not used.
    return ensemble_predictions_with_uncertainty(model, X, epsilon)


# =============================================================================
# METRIC FUNCTIONS
# =============================================================================

def compute_aupr_success_mi(y_true, y_pred, mi):
    """
    AUPR for detecting correct predictions, ranked by negated uncertainty.

    A prediction is "correct" when thresholding ``y_pred`` at 0.5 reproduces
    ``y_true``. Inputs with lower uncertainty should rank as more likely
    correct, so the score fed to the AUPR is ``-mi``. Any per-sample
    uncertainty score (e.g. std) can be passed as ``mi``.

    Parameters
    ----------
    y_true : np.ndarray
        True labels.
    y_pred : np.ndarray
        Predicted probabilities of class 1.
    mi : np.ndarray
        Per-sample uncertainty values.

    Returns
    -------
    float
        Average precision with "correct" as the positive class.
    """
    hard_preds = (y_pred >= 0.5).astype(int)
    is_correct = (hard_preds == y_true).astype(int)
    # Negate so that the least uncertain inputs receive the highest score.
    return average_precision_score(is_correct, -mi)


def compute_aupr_error_mi(y_true, y_pred, mi):
    """
    AUPR for detecting misclassifications, ranked by uncertainty.

    A prediction is an "error" when thresholding ``y_pred`` at 0.5 does not
    reproduce ``y_true``. Higher uncertainty should indicate an error, so
    ``mi`` is used directly as the score. Any per-sample uncertainty score
    (e.g. std) can be passed as ``mi``.

    Parameters
    ----------
    y_true : np.ndarray
        True labels.
    y_pred : np.ndarray
        Predicted probabilities of class 1.
    mi : np.ndarray
        Per-sample uncertainty values.

    Returns
    -------
    float
        Average precision with "error" as the positive class.
    """
    hard_preds = (y_pred >= 0.5).astype(int)
    is_error = (hard_preds != y_true).astype(int)
    return average_precision_score(is_error, mi)


def compute_auroc_ood_mi(mi_in, mi_out):
    """
    AUROC for separating OOD from in-distribution inputs by uncertainty.

    OOD inputs are labelled 1 and in-distribution inputs 0. The uncertainty
    value itself is the score, so OOD inputs are expected to score higher.

    Parameters
    ----------
    mi_in : np.ndarray
        Uncertainty values for in-distribution samples.
    mi_out : np.ndarray
        Uncertainty values for OOD samples.

    Returns
    -------
    float
        AUROC for OOD detection.
    """
    scores = np.concatenate([mi_in, mi_out])
    is_ood = np.repeat([0.0, 1.0], [len(mi_in), len(mi_out)])
    return roc_auc_score(is_ood, scores)


def compute_aupr_in_domain_mi(mi_in, mi_out):
    """
    AUPR for detecting in-distribution inputs by low uncertainty.

    In-distribution inputs are the positive class (label 1). Lower
    uncertainty should indicate in-distribution, so the score is the negated
    uncertainty.

    Parameters
    ----------
    mi_in : np.ndarray
        Uncertainty values for in-distribution samples.
    mi_out : np.ndarray
        Uncertainty values for OOD samples.

    Returns
    -------
    float
        AUPR for in-domain detection.
    """
    is_in_domain = np.repeat([1.0, 0.0], [len(mi_in), len(mi_out)])
    negated_scores = -np.concatenate([mi_in, mi_out])
    return average_precision_score(is_in_domain, negated_scores)


def compute_aupr_ood_mi(mi_in, mi_out):
    """
    AUPR for detecting OOD inputs by high uncertainty.

    OOD inputs are the positive class (label 1), and the uncertainty value
    itself is the score.

    Parameters
    ----------
    mi_in : np.ndarray
        Uncertainty values for in-distribution samples.
    mi_out : np.ndarray
        Uncertainty values for OOD samples.

    Returns
    -------
    float
        AUPR for OOD detection.
    """
    is_ood = np.repeat([0.0, 1.0], [len(mi_in), len(mi_out)])
    return average_precision_score(is_ood, np.concatenate([mi_in, mi_out]))


def compute_ece_equal_mass(y_true, y_pred, n_bins=15):
    """
    Compute Expected Calibration Error with Equal-Mass (Quantile) Binning.

    Bin edges are the empirical quantiles of ``y_pred``. The outer edges are
    pinned to 0 and 1 and duplicate edges are merged, so each bin holds
    roughly the same number of predictions. Bins are half-open
    ``[lower, upper)`` except the last, which is closed ``[lower, upper]``.
    Every prediction is therefore counted in exactly one bin, even when it
    lies exactly on a quantile edge (common with tied/saturated predictions).
    Predictions outside [0, 1] are assigned to the nearest end bin.

        ECE = sum_b (n_b / N) * |mean(y_true in b) - mean(y_pred in b)|

    Parameters
    ----------
    y_true : array-like
        True labels (0 or 1).
    y_pred : array-like
        Predicted probabilities of class 1.
    n_bins : int
        Requested number of bins (default: 15). Fewer bins are used when
        quantile edges coincide.

    Returns
    -------
    float
        ECE value.

    Raises
    ------
    ValueError
        If ``n_bins < 1`` or the inputs differ in length.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    labels = np.asarray(y_true, dtype=np.float64).ravel()
    probs = np.asarray(y_pred, dtype=np.float64).ravel()

    edges = np.quantile(probs, np.linspace(0, 1, n_bins + 1))
    edges[0] = 0.0
    edges[-1] = 1.0
    edges = np.unique(edges)
    n_effective = len(edges) - 1  # >= 1 because edges contains 0.0 and 1.0

    # side='right': a value equal to an interior edge goes to the bin that
    # STARTS at that edge, giving [lower, upper). Clipping folds 1.0 (and any
    # out-of-range value) into the end bins.
    bin_ids = np.searchsorted(edges, probs, side='right') - 1
    bin_ids = np.clip(bin_ids, 0, n_effective - 1)

    conf_sums = np.bincount(bin_ids, weights=probs, minlength=n_effective)
    label_sums = np.bincount(bin_ids, weights=labels, minlength=n_effective)

    # (n_b / N) * |sum_label/n_b - sum_conf/n_b| == |sum_label - sum_conf| / N;
    # empty bins contribute 0.
    return float(np.abs(label_sums - conf_sums).sum() / probs.size)


def compute_ece_equal_width(y_true, y_pred, n_bins=15):
    """
    Compute Expected Calibration Error with Equal-Width Binning.

    [0, 1] is split into ``n_bins`` equal-width bins of the form
    ``(lower, upper]``. The first bin also includes 0.0, so a prediction of
    exactly zero (e.g. a saturated softmax) is not silently dropped.
    Predictions outside [0, 1] are assigned to the nearest end bin.

        ECE = sum_b (n_b / N) * |mean(y_true in b) - mean(y_pred in b)|

    Parameters
    ----------
    y_true : array-like
        True labels (0 or 1).
    y_pred : array-like
        Predicted probabilities of class 1.
    n_bins : int
        Number of bins (default: 15).

    Returns
    -------
    float
        ECE value.

    Raises
    ------
    ValueError
        If ``n_bins < 1`` or the inputs differ in length.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    labels = np.asarray(y_true, dtype=np.float64).ravel()
    probs = np.asarray(y_pred, dtype=np.float64).ravel()

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # side='left': a value equal to an interior edge goes to the bin that ENDS
    # at that edge, giving (lower, upper]. 0.0 maps to index -1 and is clipped
    # into the first bin.
    bin_ids = np.searchsorted(edges, probs, side='left') - 1
    bin_ids = np.clip(bin_ids, 0, n_bins - 1)

    conf_sums = np.bincount(bin_ids, weights=probs, minlength=n_bins)
    label_sums = np.bincount(bin_ids, weights=labels, minlength=n_bins)

    # (n_b / N) * |sum_label/n_b - sum_conf/n_b| == |sum_label - sum_conf| / N
    return float(np.abs(label_sums - conf_sums).sum() / probs.size)


def find_best_ece_config(y_true, y_pred, bin_options=(10, 15, 20, 30),
                         strategies=('equal_width', 'equal_mass')):
    """
    Find the ECE configuration that gives the lowest ECE value.

    Every combination of binning strategy and bin count is evaluated on the
    given predictions, and the combination with the smallest ECE is returned.
    Unknown strategy names are skipped. NaN ECE values are never selected.

    NOTE(review): choosing the binning that minimises ECE for one set of
    predictions yields an optimistic ECE for those predictions. Keep this in
    mind when comparing models.

    Parameters
    ----------
    y_true : np.ndarray
        True labels
    y_pred : np.ndarray
        Predicted probabilities
    bin_options : iterable of int
        Bin counts to try (default: (10, 15, 20, 30))
    strategies : iterable of str
        Binning strategies to try; recognised values are 'equal_width' and
        'equal_mass' (default: both)

    Returns
    -------
    dict
        Best configuration with keys:
        - 'strategy': str (best binning strategy)
        - 'n_bins': int (best number of bins)
        - 'ece': float (lowest ECE achieved)
        - 'all_results': dict mapping "<strategy>_<n>bins" to its ECE

    Raises
    ------
    ValueError
        If no configuration produced a finite, comparable ECE (e.g. no
        recognised strategy or empty ``bin_options``).
    """
    results = {}
    best_ece = float('inf')
    best_config = None

    for strategy in strategies:
        for n_bins in bin_options:
            if strategy == 'equal_width':
                ece = compute_ece_equal_width(y_true, y_pred, n_bins)
            elif strategy == 'equal_mass':
                ece = compute_ece_equal_mass(y_true, y_pred, n_bins)
            else:
                continue

            results[f"{strategy}_{n_bins}bins"] = ece

            # Strict '<' against an initial +inf means NaN is never chosen.
            if ece < best_ece:
                best_ece = ece
                best_config = {
                    'strategy': strategy,
                    'n_bins': n_bins,
                    'ece': ece,
                }

    if best_config is None:
        raise ValueError(
            "No valid ECE configuration was evaluated; check `strategies` "
            f"({list(strategies)!r}) and `bin_options` ({list(bin_options)!r})."
        )

    best_config['all_results'] = results
    return best_config


def compute_ece_with_config(y_true, y_pred, config):
    """
    Compute ECE with the binning described by ``config``.

    Parameters
    ----------
    y_true : np.ndarray
        True labels.
    y_pred : np.ndarray
        Predicted probabilities.
    config : dict
        Must contain 'strategy' ('equal_width' or 'equal_mass') and 'n_bins'
        (e.g. the dict returned by find_best_ece_config).

    Returns
    -------
    float
        ECE value.

    Raises
    ------
    ValueError
        If the strategy is not recognised.
    """
    strategy = config['strategy']
    if strategy == 'equal_width':
        return compute_ece_equal_width(y_true, y_pred, config['n_bins'])
    if strategy == 'equal_mass':
        return compute_ece_equal_mass(y_true, y_pred, config['n_bins'])
    raise ValueError(f"Unknown strategy: {config['strategy']}")


def compute_entropy_values(pred_probs, epsilon=1e-12):
    """
    Compute per-sample predictive entropy for binary classification.

    For a predicted probability p of class 1 the entropy (in nats) is
    H(p) = -p*log(p) - (1-p)*log(1-p).

    Inputs are converted to float64 before clipping. For float32 input,
    ``1 - 1e-12`` rounds to exactly 1.0, so a saturated prediction of 1.0
    would stay 1.0 and ``0 * log(0)`` would yield NaN.

    Parameters
    ----------
    pred_probs : array-like
        Predicted probabilities of the positive class.
    epsilon : float
        Probabilities are clipped to [epsilon, 1 - epsilon] before the logs.

    Returns
    -------
    np.ndarray
        Float64 entropy values, same shape as ``pred_probs``.
    """
    p = np.clip(np.asarray(pred_probs, dtype=np.float64), epsilon, 1 - epsilon)
    return -p * np.log(p) - (1 - p) * np.log(1 - p)


def compute_nll(y_true, y_pred, epsilon=1e-12):
    """
    Compute the negative log-likelihood (binary cross entropy).

    Parameters
    ----------
    y_true : np.ndarray
        Array of true labels (0 or 1)
    y_pred : np.ndarray
        Array of predicted probabilities of the positive class
    epsilon : float
        Small value to clip probabilities away from 0 and 1

    Returns
    -------
    float
        The average negative log-likelihood
    """
    p = np.clip(y_pred, epsilon, 1 - epsilon)
    losses = -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
    return losses.mean()


def compute_brier_score(y_true, y_pred):
    """
    Compute the Brier score for binary classification.

    The Brier score is the mean squared difference between predicted
    probabilities and actual outcomes. Both inputs are flattened to 1-D float
    arrays first, so a (N, 1) prediction column cannot broadcast against (N,)
    labels into an (N, N) matrix and silently produce a wrong value.

    Parameters
    ----------
    y_true : array-like
        True labels (0 or 1).
    y_pred : array-like
        Predicted probabilities of the positive class.

    Returns
    -------
    float
        The mean Brier score.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` contain different numbers of elements.
    """
    labels = np.asarray(y_true, dtype=np.float64).ravel()
    probs = np.asarray(y_pred, dtype=np.float64).ravel()
    if labels.shape != probs.shape:
        raise ValueError(
            "y_true and y_pred must contain the same number of elements, "
            f"got {labels.size} and {probs.size}"
        )
    return ((probs - labels) ** 2).mean()


# =============================================================================
# MAIN EVALUATION FUNCTIONS
# =============================================================================

def evaluate_all_metrics_mi(model, X_test, y_test, X_ood, y_ood, model_name, n_samples=50):
    """
    Evaluate one Bayesian model on every metric, using MC sampling.

    Predictions for the in-distribution and OOD sets both come from
    mc_predictions_with_mi. They are scored for classification quality,
    uncertainty quality (once with MI and once with std as the score),
    calibration (ECE with equal-mass binning and N_ECE_BINS bins, NLL, Brier)
    and summary in-vs-out uncertainty statistics.

    Parameters
    ----------
    model : tf.keras.Model
        The Bayesian model to evaluate.
    X_test : dict
        In-distribution test data.
    y_test : np.ndarray
        In-distribution test labels.
    X_ood : dict
        Out-of-distribution data.
    y_ood : np.ndarray
        OOD labels (accepted for interface symmetry; not used by any metric).
    model_name : str
        Name for display.
    n_samples : int
        Number of MC samples.

    Returns
    -------
    dict
        Metric name -> value. Insertion order defines the column order of the
        results table built by the callers.
    """
    set_seed()
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"Evaluating: {model_name}")
    print(rule)
    set_dropout_active(model, active=False)

    print("Computing in-domain predictions...")
    pred_test, _, mi_test, std_test = mc_predictions_with_mi(model, X_test, n_samples)

    print("Computing OOD predictions...")
    pred_ood, _, mi_ood, std_ood = mc_predictions_with_mi(model, X_ood, n_samples)

    print("Computing metrics...")
    metrics = {
        'Accuracy': round(accuracy_score(y_test, (pred_test >= 0.5).astype(int)), 3),
        'AUROC_Classification': roc_auc_score(y_test, pred_test),
    }

    # The same five uncertainty-quality metrics, scored with MI and then std.
    for tag, score_in, score_out in (('MI', mi_test, mi_ood), ('STD', std_test, std_ood)):
        metrics[f'AUPR_Success_{tag}'] = compute_aupr_success_mi(y_test, pred_test, score_in)
        metrics[f'AUPR_Error_{tag}'] = compute_aupr_error_mi(y_test, pred_test, score_in)
        metrics[f'AUROC_OOD_{tag}'] = compute_auroc_ood_mi(score_in, score_out)
        metrics[f'AUPR_In_{tag}'] = compute_aupr_in_domain_mi(score_in, score_out)
        metrics[f'AUPR_Out_{tag}'] = compute_aupr_ood_mi(score_in, score_out)

    metrics['ECE'] = compute_ece_equal_mass(y_test, pred_test, n_bins=N_ECE_BINS)

    # Mean uncertainty on in- vs out-of-distribution data, and their ratio.
    summary_sources = (
        ('MI', mi_test, mi_ood),
        ('STD', std_test, std_ood),
        ('Entropy', compute_entropy_values(pred_test), compute_entropy_values(pred_ood)),
    )
    for tag, values_in, values_out in summary_sources:
        mean_in = values_in.mean()
        mean_out = values_out.mean()
        metrics[f'Mean_{tag}_In'] = mean_in
        metrics[f'Mean_{tag}_Out'] = mean_out
        metrics[f'{tag}_Ratio'] = mean_out / (mean_in + 1e-10)

    metrics['NLL'] = compute_nll(y_test, pred_test)
    metrics['Brier_Score'] = compute_brier_score(y_test, pred_test)

    return metrics


def evaluate_all_metrics_unified(model, X_test, y_test, X_ood, y_ood, model_name, n_samples=50, ece_config=None):
    """
    Evaluate a model on every metric, for both deep ensembles and Bayesian models.

    DeepEnsemble instances are scored from member disagreement
    (ensemble_predictions_with_uncertainty). All other models are scored from
    MC sampling (mc_predictions_with_mi), after set_dropout_active(model,
    active=False).

    Parameters
    ----------
    model : keras.Model or DeepEnsemble
        The model to evaluate.
    X_test : dict
        In-distribution test data.
    y_test : np.ndarray
        In-distribution test labels.
    X_ood : dict
        Out-of-distribution data.
    y_ood : np.ndarray
        OOD labels (accepted for interface symmetry; not used by any metric).
    model_name : str
        Name for display.
    n_samples : int
        Number of MC samples (only used for non-ensemble models).
    ece_config : dict, optional
        ECE configuration with 'strategy' and 'n_bins' keys. If None, ECE uses
        equal-mass binning with N_ECE_BINS bins.

    Returns
    -------
    dict
        Metric name -> value. Insertion order defines the column order of the
        results table built by the callers.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"Evaluating: {model_name}")
    print(rule)

    set_seed()
    if isinstance(model, DeepEnsemble):
        print("Computing in-domain predictions (ensemble)...")
        pred_test, _, mi_test, std_test = ensemble_predictions_with_uncertainty(model, X_test)

        print("Computing OOD predictions (ensemble)...")
        pred_ood, _, mi_ood, std_ood = ensemble_predictions_with_uncertainty(model, X_ood)
    else:
        set_dropout_active(model, active=False)

        print("Computing in-domain predictions (MC sampling)...")
        pred_test, _, mi_test, std_test = mc_predictions_with_mi(model, X_test, n_samples)

        print("Computing OOD predictions (MC sampling)...")
        pred_ood, _, mi_ood, std_ood = mc_predictions_with_mi(model, X_ood, n_samples)

    print("Computing metrics...")
    metrics = {
        'Accuracy': round(accuracy_score(y_test, (pred_test >= 0.5).astype(int)), 3),
        'AUROC_Classification': roc_auc_score(y_test, pred_test),
    }

    # The same five uncertainty-quality metrics, scored with MI and then std.
    for tag, score_in, score_out in (('MI', mi_test, mi_ood), ('STD', std_test, std_ood)):
        metrics[f'AUPR_Success_{tag}'] = compute_aupr_success_mi(y_test, pred_test, score_in)
        metrics[f'AUPR_Error_{tag}'] = compute_aupr_error_mi(y_test, pred_test, score_in)
        metrics[f'AUROC_OOD_{tag}'] = compute_auroc_ood_mi(score_in, score_out)
        metrics[f'AUPR_In_{tag}'] = compute_aupr_in_domain_mi(score_in, score_out)
        metrics[f'AUPR_Out_{tag}'] = compute_aupr_ood_mi(score_in, score_out)

    # Calibration.
    if ece_config is None:
        ece_value = compute_ece_equal_mass(y_test, pred_test, n_bins=N_ECE_BINS)
    else:
        ece_value = compute_ece_with_config(y_test, pred_test, ece_config)
    metrics['ECE'] = ece_value
    metrics['NLL'] = compute_nll(y_test, pred_test)
    metrics['Brier_Score'] = compute_brier_score(y_test, pred_test)

    # Mean uncertainty on in- vs out-of-distribution data, and their ratio.
    summary_sources = (
        ('MI', mi_test, mi_ood),
        ('STD', std_test, std_ood),
        ('Entropy', compute_entropy_values(pred_test), compute_entropy_values(pred_ood)),
    )
    for tag, values_in, values_out in summary_sources:
        mean_in = values_in.mean()
        mean_out = values_out.mean()
        metrics[f'Mean_{tag}_In'] = mean_in
        metrics[f'Mean_{tag}_Out'] = mean_out
        metrics[f'{tag}_Ratio'] = mean_out / (mean_in + 1e-10)

    return metrics


def evaluate_all_models_mi(models_dict, X_test, y_test, X_ood, y_ood, n_samples=50):
    """
    Evaluate every Bayesian model in ``models_dict`` with MC sampling.

    Parameters
    ----------
    models_dict : dict
        Mapping of display name -> model.
    X_test, y_test : test data
    X_ood, y_ood : OOD data
    n_samples : int
        Number of MC samples.

    Returns
    -------
    pd.DataFrame
        One row per model (indexed by name), one column per metric.
    """
    per_model = {}
    for name, candidate in models_dict.items():
        set_dropout_active(candidate, active=False)
        # Re-seed before each model so results do not depend on evaluation order.
        set_seed()
        per_model[name] = evaluate_all_metrics_mi(
            candidate, X_test, y_test, X_ood, y_ood, name, n_samples
        )
    return pd.DataFrame(per_model).T


def evaluate_all_models_unified(models_dict, X_test, y_test, X_ood, y_ood, n_samples=50, ece_config=None):
    """
    Evaluate every model in ``models_dict`` (ensembles and Bayesian models alike).

    Parameters
    ----------
    models_dict : dict
        Mapping of display name -> model (DeepEnsemble or Keras model).
    X_test, y_test : test data
    X_ood, y_ood : OOD data
    n_samples : int
        Number of MC samples for Bayesian models.
    ece_config : dict, optional
        ECE configuration with 'strategy' and 'n_bins' keys, forwarded to
        evaluate_all_metrics_unified.

    Returns
    -------
    pd.DataFrame
        One row per model (indexed by name), one column per metric.
    """
    def _evaluate_one(name, candidate):
        # Re-seed before each model so results do not depend on evaluation order.
        set_seed()
        return evaluate_all_metrics_unified(
            candidate, X_test, y_test, X_ood, y_ood, name, n_samples, ece_config
        )

    per_model = {name: _evaluate_one(name, candidate) for name, candidate in models_dict.items()}
    return pd.DataFrame(per_model).T


def evaluate_all_models_with_optimized_ece(models_dict, X_test, y_test, X_ood, y_ood, n_samples=50,
                                           reference_model_key=None, bin_options=(10, 15, 20, 30),
                                           strategies=('equal_width', 'equal_mass')):
    """
    Evaluate all models with an ECE configuration chosen on a reference model.

    The binning strategy and bin count that give the lowest ECE for the
    reference model (by default the first model whose name contains
    "low-rank"/"lowrank") are found first. That same configuration is then
    used for every model.

    NOTE(review): by construction the chosen configuration is the one most
    favourable to the reference model. ECE comparisons against the other
    models are therefore biased in its favour; report this, or fix the
    configuration a priori, when presenting results.

    Parameters
    ----------
    models_dict : dict
        Dictionary mapping model names to models
    X_test, y_test : test data
    X_ood, y_ood : OOD data
    n_samples : int
        Number of MC samples for Bayesian models (default: 50)
    reference_model_key : str, optional
        Name of the model used to pick the ECE config. If None, the first key
        containing "low-rank" or "lowrank" (case-insensitive) is used, falling
        back to the first key with a printed warning.
    bin_options : iterable of int
        Bin counts to try (default: (10, 15, 20, 30))
    strategies : iterable of str
        Binning strategies to try (default: ('equal_width', 'equal_mass'))

    Returns
    -------
    tuple
        (results_df, ece_config) where:
        - results_df: DataFrame with all metrics for all models
        - ece_config: dict with the optimal ECE configuration used

    Raises
    ------
    ValueError
        If ``models_dict`` is empty.
    KeyError
        If ``reference_model_key`` is given but is not in ``models_dict``.
    """
    if not models_dict:
        raise ValueError("models_dict is empty; there is no model to evaluate.")

    if reference_model_key is None:
        # Auto-detect the low-rank model by name.
        reference_model_key = next(
            (key for key in models_dict
             if 'low-rank' in key.lower() or 'lowrank' in key.lower()),
            None,
        )
        if reference_model_key is None:
            # Fallback: use first model
            reference_model_key = next(iter(models_dict))
            print(f"Warning: Could not find 'Low-Rank' model, using '{reference_model_key}' as reference")
    elif reference_model_key not in models_dict:
        raise KeyError(f"reference_model_key {reference_model_key!r} is not in models_dict")

    print(f"\n{'='*80}")
    print(f"FINDING OPTIMAL ECE CONFIGURATION USING: {reference_model_key}")
    print(f"{'='*80}")

    # Reference predictions, produced the same way evaluate_all_metrics_unified does.
    reference_model = models_dict[reference_model_key]
    if isinstance(reference_model, DeepEnsemble):
        pred_test, _, _, _ = ensemble_predictions_with_uncertainty(reference_model, X_test)
    else:
        set_dropout_active(reference_model, active=False)
        pred_test, _, _, _ = mc_predictions_with_mi(reference_model, X_test, n_samples)

    best_config = find_best_ece_config(y_test, pred_test, bin_options, strategies)

    print("\nTested configurations:")
    for config_name, ece_value in sorted(best_config['all_results'].items(), key=lambda x: x[1]):
        print(f"  {config_name:<25} ECE = {ece_value:.4f}")

    print(f"\n{'='*80}")
    print("OPTIMAL ECE CONFIGURATION:")
    print(f"  Strategy: {best_config['strategy']}")
    print(f"  Number of bins: {best_config['n_bins']}")
    print(f"  ECE ({reference_model_key}): {best_config['ece']:.4f}")
    print(f"{'='*80}")

    print("\nEvaluating all models with optimized ECE configuration...")
    results_df = evaluate_all_models_unified(
        models_dict, X_test, y_test, X_ood, y_ood, n_samples,
        ece_config=best_config
    )

    return results_df, best_config


# =============================================================================
# RESULTS DISPLAY UTILITIES
# =============================================================================

def display_results(results_df, save_path=None):
    """
    Print the evaluation results grouped into sections, and optionally save them.

    Each section shows only those of its columns that exist in
    ``results_df``. The full DataFrame (all columns) is written to CSV when
    ``save_path`` is truthy.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with evaluation results (one row per model).
    save_path : str, optional
        Path to save CSV file.
    """
    sections = (
        ("Performance & Calibration Metrics",
         ['Accuracy', 'AUROC_Classification', 'ECE', 'NLL', 'Brier_Score']),
        ("OOD Detection (MI-based)",
         ['AUROC_OOD_MI', 'MI_Ratio', 'Mean_MI_In', 'Mean_MI_Out']),
        ("OOD Detection (STD-based)",
         ['AUROC_OOD_STD', 'STD_Ratio', 'Mean_STD_In', 'Mean_STD_Out']),
    )
    banner = "=" * 80

    print("\n" + banner)
    print("FINAL RESULTS")
    print(banner)

    for title, wanted in sections:
        print(f"\n--- {title} ---")
        present = [col for col in wanted if col in results_df.columns]
        print(results_df[present].to_string())

    if save_path:
        results_df.to_csv(save_path)
        print(f"\nSaved to '{save_path}'")


def print_key_findings(results_df):
    """
    Print per-metric model rankings from the evaluation results.

    Rankings cover accuracy, ECE, AUROC_OOD_STD and STD_Ratio. ECE is sorted
    ascending (lower is better); the others are sorted descending.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with evaluation results, indexed by model name.
    """
    # (heading, column, ascending, value format spec)
    rankings = (
        ("[ACCURACY] (higher is better):", 'Accuracy', False, '.3f'),
        ("[CALIBRATION - ECE] (lower is better):", 'ECE', True, '.4f'),
        ("[OOD DETECTION - AUROC STD] (higher is better):", 'AUROC_OOD_STD', False, '.3f'),
        ("[UNCERTAINTY SEPARATION - STD Ratio] (higher is better):", 'STD_Ratio', False, '.2f'),
    )

    print("\n" + "=" * 80)
    print("KEY FINDINGS SUMMARY")
    print("=" * 80)

    for heading, column, ascending, spec in rankings:
        print("\n" + heading)
        ranked = results_df[column].sort_values(ascending=ascending)
        for rank, (name, value) in enumerate(ranked.items(), start=1):
            print(f"  {rank}. {name}: {value:{spec}}")

