"""
visualization.py - Plotting and visualization utilities for Bayesian Transformers.

This module provides:
- Training convergence plots
- Model comparison bar charts
- Uncertainty distribution plots (ID vs OOD)
- Ensemble training visualization
- Results summary visualization
- Singular value analysis for rank selection
- Metrics comparison (bar charts, radar plots)
- ROC curves with confidence intervals
- Reliability/calibration diagrams
- Uncertainty decomposition plots
- Weight correlation heatmaps
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.metrics import roc_curve, auc

from modules.config import set_seed
from modules.bayesian_layers import set_dropout_active
from modules.model_builders import DeepEnsemble
from modules.evaluation import (
    mc_predictions_with_mi,
    ensemble_predictions_with_uncertainty,
)


# =============================================================================
# TRAINING CONVERGENCE PLOTS
# =============================================================================

def plot_icml_convergence(histories_dict, save_path=None):
    """
    Plot validation accuracy trajectories of multiple models.

    Each model with a usable history contributes one line whose legend entry
    shows its best validation accuracy. A dashed horizontal line marks the
    80% target and is listed in the legend as well.

    Parameters
    ----------
    histories_dict : dict
        Dictionary mapping model names to history dictionaries.
        Structure: {'Model Name': {'val_acc': [0.7, 0.8...], ...}, ...}
        The key 'val_accuracy' is tried first, then 'val_acc'. Models with
        neither key (or with empty series) are skipped with a warning.
    save_path : str, optional
        Path to save the figure
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(10, 6))

    colors = ['black', 'blue', 'red', 'green', 'orange', 'purple']
    markers = ['o', 's', '^', 'D', 'v', 'p']

    # Length of the longest plotted history; drives the integer x-ticks.
    max_epochs = 0

    for i, (name, metrics) in enumerate(histories_dict.items()):
        # Explicit None/length checks instead of `a or b`: truthiness of a
        # numpy array raises, and an empty list would reach max() below.
        val_acc = None
        for key in ('val_accuracy', 'val_acc'):
            candidate = metrics.get(key)
            if candidate is not None and len(candidate) > 0:
                val_acc = candidate
                break

        if val_acc is None:
            print(f"Warning: No validation accuracy found for {name}")
            continue

        epochs = range(1, len(val_acc) + 1)
        max_epochs = max(max_epochs, len(val_acc))

        plt.plot(
            epochs,
            val_acc,
            label=f"{name} (Max: {max(val_acc):.2%})",
            color=colors[i % len(colors)],
            marker=markers[i % len(markers)],
            linewidth=2,
            markersize=8
        )

    plt.title("Convergence Comparison: Validation Accuracy", fontsize=14, fontweight='bold')
    plt.xlabel("Epochs", fontsize=12)
    plt.ylabel("Accuracy", fontsize=12)

    # Integer ticks covering the longest run (only if something was plotted)
    if max_epochs:
        plt.xticks(range(1, max_epochs + 1))

    # Fixed y-range: curves outside [0.70, 0.90] are clipped from view.
    plt.ylim(0.70, 0.90)

    # Draw the target line BEFORE building the legend so its label is listed.
    plt.axhline(y=0.80, color='gray', linestyle='--', alpha=0.5, label='Target (80%)')
    plt.legend(fontsize=10, loc='lower right')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


def plot_training_loss(histories_dict, save_path=None):
    """
    Draw the training-loss curve of every model on one shared axis.

    Parameters
    ----------
    histories_dict : dict
        Maps each model name to its history dictionary; the series stored
        under the 'loss' key is plotted. Entries without that key are
        reported and skipped.
    save_path : str, optional
        If given, the figure is written to this path before being shown
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(10, 6))

    # One viridis shade per dictionary entry (skipped entries keep their slot)
    palette = plt.cm.viridis(np.linspace(0, 0.8, len(histories_dict)))

    for shade, (model_name, history) in zip(palette, histories_dict.items()):
        losses = history.get('loss')
        if losses is not None:
            x_values = range(1, len(losses) + 1)
            plt.plot(x_values, losses, color=shade, linewidth=2, label=model_name)
        else:
            print(f"Warning: No training loss found for {model_name}")

    # Axis decoration
    plt.title("Training Loss Comparison", fontsize=14, fontweight='bold')
    plt.xlabel("Epochs", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.legend(fontsize=10, loc='upper right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


# =============================================================================
# ENSEMBLE TRAINING VISUALIZATION
# =============================================================================

def _ensemble_val_acc(history):
    """Return a member's validation-accuracy series, or None if absent or empty."""
    for key in ('val_acc', 'val_accuracy'):
        series = history.get(key)
        # Explicit checks: `a or b` raises on numpy arrays.
        if series is not None and len(series) > 0:
            return series
    return None


def plot_ensemble_training(ensemble_histories, title="Deep Ensemble Training", save_path=None):
    """
    Plot training curves for all ensemble members.

    The left panel shows the training loss per member, the right panel the
    validation accuracy per member. After the figure is shown, the final
    validation accuracy of each member and their mean/std are printed.
    Members lacking a loss or validation-accuracy series are skipped with a
    warning instead of aborting the whole plot.

    Parameters
    ----------
    ensemble_histories : list
        List of history dictionaries, one per ensemble member. Each is
        expected to hold 'loss' and either 'val_acc' or 'val_accuracy'.
    title : str
        Plot title
    save_path : str, optional
        Path to save the figure
    """
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    colors = plt.cm.viridis(np.linspace(0, 0.8, len(ensemble_histories)))

    # Resolve validation accuracy once so the plot and the printed statistics
    # are guaranteed to use the same series.
    val_series = [_ensemble_val_acc(history) for history in ensemble_histories]

    # Plot training loss
    ax1 = axes[0]
    for i, history in enumerate(ensemble_histories):
        loss = history.get('loss')
        if loss is None:
            print(f"Warning: No training loss found for Member {i+1}")
            continue
        epochs = range(1, len(loss) + 1)
        ax1.plot(epochs, loss,
                 color=colors[i], alpha=0.7,
                 label=f'Member {i+1}', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Training Loss', fontsize=12)
    ax1.set_title('Training Loss per Member', fontsize=14)
    ax1.legend(loc='upper right')
    ax1.grid(True, alpha=0.3)

    # Plot validation accuracy
    ax2 = axes[1]
    for i, val_acc in enumerate(val_series):
        if val_acc is None:
            print(f"Warning: No validation accuracy found for Member {i+1}")
            continue
        epochs = range(1, len(val_acc) + 1)
        ax2.plot(epochs, val_acc,
                 color=colors[i], alpha=0.7,
                 label=f'Member {i+1}', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Validation Accuracy', fontsize=12)
    ax2.set_title('Validation Accuracy per Member', fontsize=14)
    ax2.legend(loc='lower right')
    ax2.grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=16, y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()

    # Print final statistics (only for members that have a series)
    print("\nFinal Validation Accuracies:")
    final_accs = []
    for i, val_acc in enumerate(val_series):
        if val_acc is None:
            continue
        acc = val_acc[-1]
        final_accs.append(acc)
        print(f"  Member {i+1}: {acc:.4f}")

    # Mean/std of an empty list would only produce nan plus a numpy warning.
    if final_accs:
        print(f"  Mean: {np.mean(final_accs):.4f}")
        print(f"  Std:  {np.std(final_accs):.4f}")


# =============================================================================
# MODEL COMPARISON VISUALIZATION
# =============================================================================

def plot_model_comparison(results_df, save_path=None):
    """
    Create comprehensive visualization comparing all models.

    One bar-chart panel is drawn per metric (Accuracy, ECE, NLL, Brier
    score), with the numeric value printed above every bar. Metrics whose
    column is missing from ``results_df`` are skipped with a warning, and
    grid cells that end up without a panel are hidden rather than left as
    empty frames.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with evaluation results; the index supplies the model
        names and the columns 'Accuracy', 'ECE', 'NLL' and 'Brier_Score'
        supply the plotted values.
    save_path : str, optional
        Path to save the figure
    """
    sns.set_style("whitegrid")

    # (column, y-label, title, label offset above bar, value format,
    #  pin y-axis to [0, 1])
    # NOTE: the OOD panels (columns 'AUROC_OOD_STD' and 'STD_Ratio') are
    # intentionally not drawn; they were disabled in the original layout.
    panel_specs = [
        ('Accuracy', 'Accuracy', 'Classification Accuracy',
         0.01, '.3f', True),
        ('ECE', 'ECE', 'Expected Calibration Error (lower better)',
         0.001, '.4f', False),
        ('NLL', 'NLL', 'Negative Log-Likelihood (lower better)',
         0.01, '.3f', False),
        ('Brier_Score', 'Brier Score', 'Brier Score (lower better)',
         0.005, '.4f', False),
    ]

    # Create figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    flat_axes = axes.ravel()  # row-major: panels fill the top row first

    models = results_df.index.tolist()
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6B4C9A', '#1B998B']
    colors = colors[:len(models)]

    n_drawn = 0
    for column, ylabel, title, offset, fmt, unit_range in panel_specs:
        if column not in results_df.columns:
            print(f"Warning: column '{column}' not found in results; skipping panel")
            continue

        ax = flat_axes[n_drawn]
        n_drawn += 1

        values = results_df[column]
        bars = ax.bar(models, values, color=colors, alpha=0.8, edgecolor='black')
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        if unit_range:
            ax.set_ylim(0, 1)
        ax.tick_params(axis='x', rotation=15)

        # Numeric label centred just above each bar
        for bar, val in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + offset,
                    format(val, fmt), ha='center', fontsize=10)

    # Hide grid cells that received no panel so no blank axes are rendered.
    for unused_ax in flat_axes[n_drawn:]:
        unused_ax.set_visible(False)

    plt.suptitle('Model Comparison: Accuracy, Calibration & Uncertainty',
                 fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


# =============================================================================
# UNCERTAINTY DISTRIBUTION VISUALIZATION
# =============================================================================

def plot_uncertainty_distributions(models_dict, X_id, X_ood, n_samples=50, save_path=None):
    """
    Plot uncertainty distributions for in-distribution vs OOD data.

    For every model one row with two panels is drawn: overlaid density
    histograms of the predictive std (left) and of the mutual information
    (right), each comparing in-distribution against OOD inputs. Mutual
    information is clipped to [0, 99th percentile of the pooled ID+OOD
    values] for display only.

    Parameters
    ----------
    models_dict : dict
        Dictionary of models to evaluate. DeepEnsemble instances are
        evaluated with ensemble_predictions_with_uncertainty, all other
        models with mc_predictions_with_mi.
    X_id : dict
        In-distribution data
    X_ood : dict
        Out-of-distribution data
    n_samples : int
        Number of MC samples for Bayesian models
    save_path : str, optional
        Path to save the figure
    """
    n_models = len(models_dict)
    if n_models == 0:
        # plt.subplots(0, 2) would raise; there is nothing to draw anyway.
        print("Warning: no models supplied; nothing to plot")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single model.
    fig, axes = plt.subplots(n_models, 2, figsize=(12, 3*n_models), squeeze=False)

    colors_id = '#2E86AB'  # Blue for ID
    colors_ood = '#C73E1D'  # Red for OOD

    for idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"Computing uncertainty for {model_name}...")

        # Get predictions
        if isinstance(model, DeepEnsemble):
            _, _, mi_id, std_id = ensemble_predictions_with_uncertainty(model, X_id)
            _, _, mi_ood, std_ood = ensemble_predictions_with_uncertainty(model, X_ood)
        else:
            # NOTE(review): the dropout state is not modified here (the
            # set_dropout_active call was disabled) - confirm this is intended.
            _, _, mi_id, std_id = mc_predictions_with_mi(model, X_id, n_samples)
            _, _, mi_ood, std_ood = mc_predictions_with_mi(model, X_ood, n_samples)

        # Clip MI for visualization (can be very large for some models).
        # The cap is shared by both histograms, so compute it only once.
        mi_cap = np.percentile(np.concatenate([mi_id, mi_ood]), 99)

        panels = [
            (axes[idx, 0], std_id, std_ood,
             'Predictive Std', f'{model_name}: Std Distribution'),
            (axes[idx, 1], np.clip(mi_id, 0, mi_cap), np.clip(mi_ood, 0, mi_cap),
             'Mutual Information', f'{model_name}: MI Distribution'),
        ]

        for ax, id_values, ood_values, xlabel, panel_title in panels:
            ax.hist(id_values, bins=50, alpha=0.6, color=colors_id,
                    label='In-Distribution', density=True, edgecolor='white')
            ax.hist(ood_values, bins=50, alpha=0.6, color=colors_ood,
                    label='OOD', density=True, edgecolor='white')
            ax.set_xlabel(xlabel, fontsize=10)
            ax.set_ylabel('Density', fontsize=10)
            ax.set_title(panel_title, fontsize=12, fontweight='bold')
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)

    plt.suptitle('Uncertainty Distributions: In-Distribution vs OOD',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


def plot_calibration_diagram(y_true, y_pred, n_bins=15, model_name="Model", save_path=None):
    """
    Plot reliability diagram (calibration curve).

    Predictions are grouped into equal-frequency bins (bin edges are the
    quantiles of ``y_pred``, with the outer edges pinned to 0 and 1). For
    every non-empty bin the mean predicted probability is plotted against
    the mean of ``y_true`` in that bin, next to the diagonal of perfect
    calibration.

    Parameters
    ----------
    y_true : np.ndarray
        True labels
    y_pred : np.ndarray
        Predicted probabilities
    n_bins : int
        Number of bins
    model_name : str
        Model name for title
    save_path : str, optional
        Path to save the figure

    Raises
    ------
    ValueError
        If ``n_bins`` is smaller than 1 or the inputs differ in length.
    """
    # Accept lists and (N, 1) arrays alike; boolean masking needs 1-D arrays.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length "
            f"({y_true.size} != {y_pred.size})"
        )
    if y_pred.size == 0:
        print("Warning: empty predictions; nothing to plot")
        return

    plt.figure(figsize=(8, 8))

    # Calculate calibration curve
    quantiles = np.linspace(0, 1, n_bins + 1)
    bin_boundaries = np.quantile(y_pred, quantiles)
    bin_boundaries[0] = 0.0
    bin_boundaries[-1] = 1.0

    bin_mids = []
    bin_accs = []

    last_bin = len(bin_boundaries) - 2
    for i in range(len(bin_boundaries) - 1):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        if i == last_bin:
            # Close the final bin on the right so predictions that are
            # exactly 1.0 are counted instead of silently dropped.
            in_bin = (y_pred >= lower) & (y_pred <= upper)
        else:
            in_bin = (y_pred >= lower) & (y_pred < upper)

        # Tied quantiles can yield empty bins; skip them.
        if in_bin.sum() > 0:
            bin_mids.append(y_pred[in_bin].mean())
            bin_accs.append(y_true[in_bin].mean())

    # Plot
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
    plt.bar(bin_mids, bin_accs, width=1.0/n_bins, alpha=0.7, edgecolor='black', label='Model')
    plt.scatter(bin_mids, bin_accs, color='red', zorder=5)

    plt.xlabel('Mean Predicted Probability', fontsize=12)
    plt.ylabel('Fraction of Positives', fontsize=12)
    plt.title(f'Calibration Diagram: {model_name}', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xlim(0, 1)
    plt.ylim(0, 1)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


def plot_prediction_histograms(models_dict, X_test, n_samples=50, save_path=None):
    """
    Show, side by side, a histogram of each model's predicted probabilities.

    Parameters
    ----------
    models_dict : dict
        Maps a display name to a model. DeepEnsemble instances are evaluated
        through ensemble_predictions_with_uncertainty; for every other model
        set_dropout_active(model, active=False) is called first and the
        predictions come from mc_predictions_with_mi.
    X_test : dict
        Test data handed to the prediction helpers
    n_samples : int
        Number of MC samples passed to mc_predictions_with_mi
    save_path : str, optional
        If given, the figure is written to this path before being shown
    """
    model_count = len(models_dict)
    fig, axes = plt.subplots(1, model_count, figsize=(4*model_count, 4))

    # With a single column plt.subplots hands back a bare Axes object
    panel_axes = [axes] if model_count == 1 else axes

    for ax, (model_name, model) in zip(panel_axes, models_dict.items()):
        if not isinstance(model, DeepEnsemble):
            set_dropout_active(model, active=False)
            pred, _, _, _ = mc_predictions_with_mi(model, X_test, n_samples)
        else:
            pred, _, _, _ = ensemble_predictions_with_uncertainty(model, X_test)

        # Only the first returned element (the predictions) is plotted
        ax.hist(pred, bins=50, alpha=0.7, edgecolor='black')
        ax.set_xlabel('Predicted Probability', fontsize=10)
        ax.set_ylabel('Count', fontsize=10)
        ax.set_title(model_name, fontsize=12, fontweight='bold')
        ax.set_xlim(0, 1)
        ax.grid(True, alpha=0.3)

    plt.suptitle('Prediction Probability Distributions', fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


# =============================================================================
# SUMMARY VISUALIZATION
# =============================================================================

def create_results_heatmap(results_df, metrics=None, save_path=None):
    """
    Create a heatmap visualization of evaluation metrics.

    Each metric column is min-max scaled to [0, 1] for colouring, while the
    cell annotations show the original values rounded to 3 decimals.
    Columns whose name contains 'ECE', 'NLL' or 'Brier' are treated as
    lower-is-better and inverted, so that green always marks the better
    model on the red-yellow-green colormap. Columns with a single distinct
    value are shown at the neutral midpoint 0.5.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with evaluation results
    metrics : list, optional
        List of metrics to include (default: key metrics). Metrics that are
        not columns of ``results_df`` are ignored.
    save_path : str, optional
        Path to save the figure
    """
    if metrics is None:
        metrics = ['Accuracy', 'ECE', 'NLL', 'Brier_Score',
                   'AUROC_OOD_STD', 'STD_Ratio']

    # Filter to available metrics
    available_metrics = [m for m in metrics if m in results_df.columns]
    if not available_metrics or results_df.empty:
        # seaborn fails with an obscure error on an empty frame.
        print("Warning: none of the requested metrics are available; nothing to plot")
        return

    data = results_df[available_metrics]

    # Normalize for heatmap (0-1 scale, 1 = best)
    lower_is_better_tags = ('ECE', 'NLL', 'Brier')
    data_normalized = data.astype(float)
    for col in available_metrics:
        column = data_normalized[col]
        lo, hi = column.min(), column.max()
        if hi == lo:
            # Constant column: neutral colour; `* 0` keeps NaN cells as NaN.
            scaled = column * 0 + 0.5
        elif any(tag in str(col) for tag in lower_is_better_tags):
            scaled = (hi - column) / (hi - lo)
        else:
            scaled = (column - lo) / (hi - lo)
        data_normalized[col] = scaled

    plt.figure(figsize=(12, 6))
    sns.heatmap(data_normalized, annot=data.round(3), fmt='',
                cmap='RdYlGn', center=0.5, vmin=0, vmax=1, linewidths=0.5,
                cbar_kws={'label': 'Normalized score (1 = best)'})
    plt.title('Model Comparison Heatmap (Normalized)', fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


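# =============================================================================
# SINGULAR VALUE ANALYSIS
# =============================================================================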

def analyze_singular_values(model, rank=15, figsize=(16, 10), save_path=None):
    """
    Analyze and plot singular values of ALL learned weight matrices from deterministic model.

    Helps to understand the effective rank of the weight matrices and to
    guide the choice of rank for low-rank Bayesian approximations. For every
    layer whose first weight array is a 2-D matrix, the top row shows the
    singular values (log scale) and the bottom row the cumulative energy
    (sum of squared singular values, in percent).

    Parameters
    ----------
    model : tf.keras.Model
        Trained deterministic model
    rank : int
        Chosen rank for low-rank approximation (default: 15)
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    list
        Energy captured by chosen rank for each layer (empty if the model
        has no 2-D weight matrices)

    Raises
    ------
    ValueError
        If ``rank`` is smaller than 1.
    """
    if rank < 1:
        # rank - 1 would wrap around to index -1 and report 100% energy.
        raise ValueError(f"rank must be >= 1, got {rank}")

    # Extract all dense layers; get_weights() is called once per layer.
    dense_layers = []
    for i, layer in enumerate(model.layers):
        if not hasattr(layer, 'get_weights'):
            continue
        layer_weights = layer.get_weights()
        if len(layer_weights) == 0:
            continue
        kernel = layer_weights[0]
        if len(kernel.shape) == 2 and kernel.size > 0:  # Only dense layers
            dense_layers.append((i, kernel))

    n_layers = len(dense_layers)
    print(f"Found {n_layers} dense layers in model")
    if n_layers == 0:
        # plt.subplots(2, 0) would raise; nothing to analyze.
        return []

    # Create subplots: 2 rows (singular values + cumulative energy) x n_layers columns
    # squeeze=False keeps `axes` two-dimensional for a single layer.
    fig, axes = plt.subplots(2, n_layers, figsize=figsize, squeeze=False)

    all_energies = []

    for col_idx, (layer_idx, W) in enumerate(dense_layers):
        # Only the singular values are needed, so skip computing U and Vt.
        s = np.linalg.svd(W, compute_uv=False)

        # Compute cumulative energy (guard against an all-zero matrix)
        squared = s**2
        total_energy = np.sum(squared)
        if total_energy > 0:
            cumulative_energy = np.cumsum(squared) / total_energy * 100
        else:
            cumulative_energy = np.zeros_like(squared)

        # Find energy captured by chosen rank (clamped to the matrix rank)
        energy_at_rank = cumulative_energy[min(rank-1, len(s)-1)]
        all_energies.append(energy_at_rank)

        rank_in_range = rank <= len(s)

        # Plot 1: Singular values (top row)
        ax1 = axes[0, col_idx]
        ax1.plot(range(1, len(s)+1), s, 'o-', color='#264653', markersize=4, lw=2)
        if rank_in_range:
            ax1.axvline(x=rank, color='#E63946', linestyle='--', lw=2,
                        label=f'r={rank}')
            # Legend only when a labelled artist exists (avoids a warning).
            ax1.legend(loc='upper right', fontsize=8)
        ax1.set_xlabel('Index (i)', fontsize=10, fontweight='bold')
        ax1.set_ylabel('σᵢ(W*)', fontsize=10, fontweight='bold')
        ax1.set_title(f'Layer {layer_idx}: {W.shape}',
                      fontsize=11, fontweight='bold')
        ax1.grid(alpha=0.3, linestyle='--')
        ax1.set_yscale('log')

        # Plot 2: Cumulative energy (bottom row)
        ax2 = axes[1, col_idx]
        ax2.plot(range(1, len(cumulative_energy)+1), cumulative_energy, 'o-',
                 color='#2A9D8F', markersize=4, lw=2)
        if rank_in_range:
            ax2.axvline(x=rank, color='#E63946', linestyle='--', lw=2,
                        label=f'{energy_at_rank:.1f}%')
            ax2.axhline(y=energy_at_rank, color='#E63946', linestyle=':', lw=1.5, alpha=0.7)
            ax2.legend(loc='lower right', fontsize=8)
        ax2.set_xlabel('Rank', fontsize=10, fontweight='bold')
        ax2.set_ylabel('Energy (%)', fontsize=10, fontweight='bold')
        ax2.set_title(f'r={rank}: {energy_at_rank:.1f}%',
                      fontsize=11, fontweight='bold')
        ax2.grid(alpha=0.3, linestyle='--')
        ax2.set_ylim([0, 105])

        print(f"\nLayer {layer_idx} ({W.shape}):")
        print(f"  Total singular values: {len(s)}")
        print(f"  Energy at rank {rank}: {energy_at_rank:.2f}%")
        print(f"  Top 5 singular values: {s[:5]}")

    fig.suptitle(f'Singular Value Analysis: All Layers (Chosen rank r={rank})',
                 fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to '{save_path}'")

    plt.show()

    avg_energy = np.mean(all_energies)
    print(f"\n{'='*80}")
    print(f"SUMMARY: Rank r={rank} captures {avg_energy:.2f}% energy on average across all layers")
    print(f"{'='*80}")

    return all_energies


def analyze_singular_values_selected(model, rank=15, layer_names=None, layer_indices=None,
                                     figsize=(14, 8), save_path=None):
    """
    Analyze and plot singular values for SELECTED layers only (cleaner visualization).

    Singular values are drawn on a LINEAR scale to make the decay visually
    obvious. At most 2 layers are plotted for readability; additional
    matches are ignored. Layers are always taken in model order.

    Parameters
    ----------
    model : tf.keras.Model
        Trained deterministic model
    rank : int
        Chosen rank for low-rank approximation (default: 15). Used for the
        reported energy / decay statistics; it is not drawn on the plots.
    layer_names : list of str, optional
        Case-insensitive substrings of layer names to analyze
        (e.g., ['embedding', 'q_proj']). Takes precedence over
        layer_indices. If None, uses layer_indices
    layer_indices : list of int, optional
        Indices into ``model.layers`` to analyze (e.g., [0, 5]).
        If None and layer_names is None, shows first 2 dense layers
    figsize : tuple
        Figure size (default: (14, 8))
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    dict
        Maps layer name to a dict with keys 'layer_idx', 'shape',
        'singular_values', 'energy_at_rank' and 'total_singular_values'.
        Empty if no layer matched.

    Raises
    ------
    ValueError
        If ``rank`` is smaller than 1.

    Example
    -------
    # Show embedding and first attention layer
    analyze_singular_values_selected(model, rank=12, layer_names=['embedding', 'q_proj'])

    # Show specific layer indices (only the first 2 matches are plotted)
    analyze_singular_values_selected(model, rank=12, layer_indices=[0, 5])
    """
    if rank < 1:
        # rank - 1 would wrap around to index -1 and report 100% energy.
        raise ValueError(f"rank must be >= 1, got {rank}")

    max_layers = 2  # Max 2 layers for readability

    # Extract all dense layers; get_weights() is called once per layer.
    all_dense_layers = []
    for i, layer in enumerate(model.layers):
        if not hasattr(layer, 'get_weights'):
            continue
        layer_weights = layer.get_weights()
        if len(layer_weights) == 0:
            continue
        kernel = layer_weights[0]
        if len(kernel.shape) == 2 and kernel.size > 0:  # Only dense layers
            all_dense_layers.append((i, layer, kernel))

    # Select layers to plot
    if layer_names is not None:
        # Select by (case-insensitive) name substring
        wanted = [name.lower() for name in layer_names]
        matches = [entry for entry in all_dense_layers
                   if any(name in entry[1].name.lower() for name in wanted)]
    elif layer_indices is not None:
        # Select by index
        matches = [entry for entry in all_dense_layers if entry[0] in layer_indices]
    else:
        # Default: first dense layers
        matches = all_dense_layers
    selected_layers = matches[:max_layers]

    n_selected = len(selected_layers)
    if n_selected == 0:
        print("No layers found matching the criteria!")
        return {}

    print(f"Plotting {n_selected} selected layer(s) out of {len(all_dense_layers)} total dense layers")

    # Create subplots: 2 rows x n_selected columns
    # squeeze=False keeps `axes` two-dimensional for a single layer.
    fig, axes = plt.subplots(2, n_selected, figsize=figsize, squeeze=False)

    results = {}

    for col_idx, (layer_idx, layer, W) in enumerate(selected_layers):
        # Only the singular values are needed, so skip computing U and Vt.
        s = np.linalg.svd(W, compute_uv=False)

        # Compute cumulative energy (guard against an all-zero matrix)
        squared = s**2
        total_energy = np.sum(squared)
        if total_energy > 0:
            cumulative_energy = np.cumsum(squared) / total_energy * 100
        else:
            cumulative_energy = np.zeros_like(squared)

        # Index of the chosen rank, clamped to the number of singular values
        rank_pos = min(rank-1, len(s)-1)
        energy_at_rank = cumulative_energy[rank_pos]
        sigma_at_rank = s[rank_pos]
        # Avoid a divide-by-zero warning when the matrix is rank deficient.
        decay_ratio = s[0] / sigma_at_rank if sigma_at_rank > 0 else float('inf')

        # Plot 1: Singular values (top row) - LINEAR SCALE (shows decay more dramatically)
        ax1 = axes[0, col_idx]
        ax1.plot(range(1, len(s)+1), s, 'o-', color='#264653', markersize=5, lw=2.5)
        ax1.set_xlabel('Singular Value Index (i)', fontsize=12, fontweight='bold')
        ax1.set_ylabel('σᵢ (Linear Scale)', fontsize=12, fontweight='bold')
        ax1.set_title(f'{layer.name}\nShape: {W.shape}',
                      fontsize=12, fontweight='bold')
        ax1.grid(alpha=0.3, linestyle='--')

        # Plot 2: Cumulative energy (bottom row)
        ax2 = axes[1, col_idx]
        ax2.plot(range(1, len(cumulative_energy)+1), cumulative_energy, 'o-',
                 color='#2A9D8F', markersize=5, lw=2.5)
        ax2.set_xlabel('Rank', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Cumulative Energy (%)', fontsize=12, fontweight='bold')
        ax2.set_title('Cumulative Energy',
                      fontsize=12, fontweight='bold')
        ax2.grid(alpha=0.3, linestyle='--')
        ax2.set_ylim([0, 105])

        # Store results
        results[layer.name] = {
            'layer_idx': layer_idx,
            'shape': W.shape,
            'singular_values': s,
            'energy_at_rank': energy_at_rank,
            'total_singular_values': len(s)
        }

        print(f"\n{layer.name} (Layer {layer_idx}, Shape {W.shape}):")
        print(f"  Total singular values: {len(s)}")
        print(f"  Energy at rank {rank}: {energy_at_rank:.2f}%")
        print(f"  Singular value at rank {rank}: {sigma_at_rank:.6f}")
        print(f"  Top 5 singular values: {s[:5]}")
        print(f"  Decay ratio (σ₁/σᵣ): {decay_ratio:.2f}x")

    fig.suptitle('Singular Value Analysis - Selected Layers',
                 fontsize=15, fontweight='bold', y=0.98)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to '{save_path}'")

    plt.show()

    return results


def plot_transformer_singular_values(model, rank=16, figsize=(18, 8), save_path=None):
    """
    Optimized singular value analysis for Transformer models.

    Picks, by layer name, the first embedding layer (name contains 'emb')
    and the first value-projection layer (name contains 'value', '_v' or
    'v_proj') among the layers whose first weight array is 2-D, and forwards
    them to analyze_singular_values_selected. If either one cannot be found,
    the first two such layers are used instead.

    Parameters
    ----------
    model : tf.keras.Model
        Trained baseline Transformer model
    rank : int
        Rank for low-rank approximation (default: 16)
    figsize : tuple
        Figure size (default: (18, 8))
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    dict
        Results with layer info and energy captured (empty if the model has
        no 2-D weight matrices)
    """
    embedding_idx, value_idx = None, None
    # Indices of all layers with a 2-D kernel, kept for the fallback below.
    dense_indices = []

    for i, layer in enumerate(model.layers):
        if not hasattr(layer, 'get_weights'):
            continue
        # Fetch the weights once per layer.
        layer_weights = layer.get_weights()
        if len(layer_weights) == 0 or len(layer_weights[0].shape) != 2:
            continue
        dense_indices.append(i)

        name = layer.name.lower()
        if embedding_idx is None and 'emb' in name:
            embedding_idx = i
        elif value_idx is None and ('value' in name or '_v' in name or 'v_proj' in name):
            value_idx = i

        # Stop scanning as soon as both targets are known; the fallback is
        # only reached when the loop ran to completion.
        if embedding_idx is not None and value_idx is not None:
            break

    if embedding_idx is not None and value_idx is not None:
        layer_indices = [embedding_idx, value_idx]
    else:
        # Fallback if not found
        print("Warning: Could not find embedding or value layers, using first 2 dense layers")
        if not dense_indices:
            # Previously this path raised IndexError on dense_layers[0].
            print("No dense layers found in model!")
            return {}
        layer_indices = dense_indices[:2]

    # Call the main plotting function
    return analyze_singular_values_selected(
        model,
        rank=rank,
        layer_indices=layer_indices,
        figsize=figsize,
        save_path=save_path
    )


# =============================================================================
# METRICS COMPARISON PLOTS
# =============================================================================

def plot_metrics_comparison(results_df, metrics_to_plot=None, highlight_models=None,
                           figsize=(16, 10), save_path=None):
    """
    Create grouped bar plot comparing all metrics across models.

    One horizontal bar chart is drawn per metric. Highlighted models are
    drawn in blue with bold value labels, all others in gray. In every
    chart the best model gets a gold border: the minimum for metrics whose
    name contains 'ECE', 'NLL' or 'Brier', the maximum otherwise. NaN values
    are ignored when picking the best model and receive no value label.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with models as rows, metrics as columns
    metrics_to_plot : list, optional
        List of metrics to plot (default: all)
    highlight_models : list, optional
        List of model names to highlight (different color); a model is
        highlighted when any entry is a substring of its name
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    """
    if metrics_to_plot is None:
        metrics_to_plot = results_df.columns.tolist()

    if highlight_models is None:
        highlight_models = []

    n_metrics = len(metrics_to_plot)
    n_models = len(results_df)

    if n_metrics == 0 or n_models == 0:
        # plt.subplots(1, 0) and argmin/argmax on empty data would raise.
        print("Warning: no metrics or no models to plot")
        return

    # Create subplots; squeeze=False always yields a 2-D array of axes.
    fig, axes = plt.subplots(1, n_metrics, figsize=figsize, sharey=False, squeeze=False)
    axes = axes[0]

    # Decide once per model whether it is highlighted (used for both the bar
    # colour and the label style).
    is_highlighted = [
        any(highlight in str(model_name) for highlight in highlight_models)
        for model_name in results_df.index
    ]

    # Color scheme: blue for highlighted models, gray for others
    colors = ['#2E86AB' if flag else '#A9A9A9' for flag in is_highlighted]
    positions = np.arange(n_models)

    # Plot each metric
    for ax, metric in zip(axes, metrics_to_plot):
        values = np.asarray(results_df[metric].values, dtype=float)
        has_data = not np.all(np.isnan(values))

        bars = ax.barh(positions, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        # Highlight best model with gold border
        if has_data:
            if 'ECE' in metric or 'NLL' in metric or 'Brier' in metric:
                best_idx = np.nanargmin(values)  # Lower is better
            else:
                best_idx = np.nanargmax(values)  # Higher is better
            bars[best_idx].set_edgecolor('gold')
            bars[best_idx].set_linewidth(3)

        # Add value labels
        for pos, val, flag in zip(positions, values, is_highlighted):
            if np.isnan(val):
                continue
            ax.text(val + 0.01, pos, f'{val:.3f}', va='center',
                    fontweight='bold' if flag else 'normal',
                    fontsize=10 if flag else 9)

        # Formatting
        ax.set_yticks(positions)
        ax.set_yticklabels(results_df.index, fontsize=10)
        ax.set_xlabel(metric.replace('_', ' '), fontsize=12, fontweight='bold')
        # Leave 15% headroom for the value labels. Skipped when the largest
        # value is not positive, where (0, max * 1.15) would be degenerate;
        # matplotlib's automatic limits are used in that case.
        if has_data and np.nanmax(values) > 0:
            ax.set_xlim(0, np.nanmax(values) * 1.15)
        ax.grid(axis='x', alpha=0.3, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()


def plot_metrics_radar(results_df, metrics_to_plot=None, highlight_models=None,
                       figsize=(12, 12), save_path=None):
    """
    Draw a radar (spider) chart comparing models across several metrics.

    Each metric is min-max scaled to [0, 1] over the rows of ``results_df``.
    Columns whose name contains 'ECE', 'NLL' or 'Brier' are treated as
    lower-is-better and flipped, so a larger radius always means "better".
    A column in which every model has the same value is drawn at 0.5.

    Parameters
    ----------
    results_df : pd.DataFrame
        Models as rows (index = model name), metrics as columns
    metrics_to_plot : list, optional
        Columns to draw; defaults to every column of ``results_df``
    highlight_models : list, optional
        Substrings; a model whose name contains any of them gets a thick,
        coloured line, every other model a thin gray one
    figsize : tuple
        Figure size
    save_path : str, optional
        If given, the figure is saved there at 300 dpi
    """
    chosen_metrics = results_df.columns.tolist() if metrics_to_plot is None else metrics_to_plot
    emphasised = [] if highlight_models is None else highlight_models
    lower_is_better_tags = ('ECE', 'NLL', 'Brier')

    # Min-max scale every metric so that 1.0 is always the best model
    scaled = results_df[chosen_metrics].copy()
    for column in scaled.columns:
        lowest = scaled[column].min()
        highest = scaled[column].max()

        if highest == lowest:
            # No spread between models: park everyone mid-radius
            scaled[column] = 0.5
            continue

        spread = highest - lowest
        if any(tag in column for tag in lower_is_better_tags):
            scaled[column] = (highest - scaled[column]) / spread
        else:
            scaled[column] = (scaled[column] - lowest) / spread

    # One spoke per metric; repeat the first angle so each polygon closes
    axis_labels = [metric.replace('_', '\n') for metric in chosen_metrics]
    spoke_angles = np.linspace(0, 2 * np.pi, len(axis_labels), endpoint=False).tolist()
    closed_angles = spoke_angles + spoke_angles[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={'projection': 'polar'})

    for row_number, model_name in enumerate(scaled.index):
        radii = scaled.loc[model_name].values.tolist()
        radii = radii + radii[:1]

        if any(key in model_name for key in emphasised):
            stroke, opacity, shade = 3, 0.9, plt.cm.tab10(row_number)
        else:
            stroke, opacity, shade = 1.5, 0.4, 'gray'

        ax.plot(closed_angles, radii, 'o-', linewidth=stroke,
                label=model_name, alpha=opacity, color=shade)
        ax.fill(closed_angles, radii, alpha=0.15, color=shade)

    # Axis cosmetics
    radial_ticks = [0.2, 0.4, 0.6, 0.8, 1.0]
    ax.set_xticks(closed_angles[:-1])
    ax.set_xticklabels(axis_labels, size=11)
    ax.set_ylim(0, 1)
    ax.set_yticks(radial_ticks)
    ax.set_yticklabels([f'{tick:.1f}' for tick in radial_ticks], size=9)
    ax.grid(True, linestyle='--', alpha=0.5)

    ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1.0),
              fontsize=10, frameon=False, ncol=1)
    ax.set_title('Model Performance Comparison\n(All metrics normalized to [0,1], higher=better)',
                 size=14, fontweight='bold', pad=20)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()


def plot_uncertainty_comparison(results_df, highlight_models=None, figsize=(14, 6), save_path=None):
    """
    Create a two-panel figure comparing uncertainty-related metrics.

    Left panel: grouped vertical bars for up to three OOD-detection columns,
    i.e. columns whose name contains one of 'AUPR', 'AUROC_OOD',
    'Uncertainty', 'Ratio' and additionally contains 'OOD' or 'In_Domain'.
    Bars are coloured per model (highlighted vs. not); the metrics within a
    group are told apart by hatch pattern, which is what the legend shows.

    Right panel: horizontal bars for the first column whose name contains
    'Ratio', with a dashed reference line at 1.0x. If no such column exists
    the panel shows a short notice instead.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with models as rows, metrics as columns
    highlight_models : list, optional
        Substrings; models whose name contains any of them are drawn in blue
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    """
    if highlight_models is None:
        highlight_models = []

    # One colour per model, shared by both panels
    colors = ['#2E86AB' if any(h in model for h in highlight_models) else '#A9A9A9'
              for model in results_df.index]

    # Select uncertainty metrics, then the (at most three) OOD-related ones
    unc_metrics = [col for col in results_df.columns
                   if any(x in col for x in ['AUPR', 'AUROC_OOD', 'Uncertainty', 'Ratio'])]
    ood_metrics = [col for col in unc_metrics if 'OOD' in col or 'In_Domain' in col][:3]

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Left: OOD detection metrics
    ax = axes[0]
    x = np.arange(len(results_df))
    width = 0.25
    # Bars of one model share a colour, so hatching distinguishes the metrics
    hatches = ['', '//', '..']

    for i, metric in enumerate(ood_metrics):
        ax.bar(x + i * width, results_df[metric].values, width,
               label=metric.replace('_', ' '), color=colors, alpha=0.8,
               edgecolor='black', hatch=hatches[i])

    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('OOD Detection Performance', fontsize=13, fontweight='bold')
    # Centre each tick under its group for however many metrics were drawn
    ax.set_xticks(x + width * max(len(ood_metrics) - 1, 0) / 2)
    ax.set_xticklabels(results_df.index, rotation=45, ha='right', fontsize=9)
    if ood_metrics:
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1.0), loc='upper left', frameon=False)
    ax.grid(axis='y', alpha=0.3)

    # Right: Uncertainty ratios
    ax = axes[1]
    ratio_cols = [col for col in results_df.columns if 'Ratio' in col]
    if ratio_cols:
        values = results_df[ratio_cols[0]].values
        rows = range(len(results_df))

        ax.barh(rows, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        for i, val in enumerate(values):
            ax.text(val + 0.05, i, f'{val:.2f}x', va='center', fontweight='bold')

        ax.set_yticks(rows)
        ax.set_yticklabels(results_df.index, fontsize=10)
        ax.set_xlabel('Uncertainty Ratio (OOD/In-Domain)', fontsize=12, fontweight='bold')
        ax.set_title('OOD Uncertainty Amplification', fontsize=13, fontweight='bold')
        ax.axvline(1.0, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Baseline (1.0x)')
        ax.legend(bbox_to_anchor=(1.02, 1.0), loc='upper left', frameon=False)
        ax.grid(axis='x', alpha=0.3)
    else:
        ax.text(0.5, 0.5, "No 'Ratio' column in results_df", ha='center', va='center')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()


def plot_focused_performance_radar(results_df, metrics_config=None, highlight_models=None,
                                   figsize=(12, 12), save_path=None):
    """
    Draw a radar chart restricted to a configurable set of key metrics.

    Each configured metric that exists as a column of ``results_df`` is
    min-max scaled to [0, 1] across models; metrics flagged
    ``higher_is_better=False`` are flipped so that a larger radius is always
    better, and a metric with identical values for all models is drawn at 0.5.
    The raw values, the scaled values and a per-metric direction/range
    summary are printed to stdout.

    With the default configuration the metrics are Accuracy,
    AUPR_Success_MI, AUROC_OOD_MI, AUPR_In_MI, AUPR_Out_MI (all higher is
    better) and NLL (lower is better, inverted for display).

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with models as rows, metrics as columns
    metrics_config : dict, optional
        Maps column names to their display properties:
        {'column_name': {'label': 'Display Label', 'higher_is_better': True/False}}
        If None, the default configuration above is used
    highlight_models : list, optional
        Substrings; models whose name contains any of them are drawn with
        thicker lines, larger markers and stronger fill
    figsize : tuple
        Figure size (default: (12, 12))
    save_path : str, optional
        Path to save figure

    Returns
    -------
    None
        Returns early (after printing the available columns) when none of
        the configured metrics is present in ``results_df``.

    Example
    -------
    plot_focused_performance_radar(
        results_df,
        highlight_models=['Low-Rank BBB'],
        save_path='radar_plot.png'
    )
    """
    if metrics_config is None:
        metrics_config = {
            'Accuracy': {'label': 'Accuracy', 'higher_is_better': True},
            'AUPR_Success_MI': {'label': 'AUPR\nSuccess', 'higher_is_better': True},
            'AUROC_OOD_MI': {'label': 'AUROC\nOOD', 'higher_is_better': True},
            'AUPR_In_MI': {'label': 'AUPR\nIn-Domain', 'higher_is_better': True},
            'AUPR_Out_MI': {'label': 'AUPR\nOOD', 'higher_is_better': True},
            'NLL': {'label': 'NLL\n(inverted)', 'higher_is_better': False},
        }

    # Keep only the configured metrics that are actually columns
    usable = {name: spec for name, spec in metrics_config.items()
              if name in results_df.columns}

    if not usable:
        print("Error: None of the requested metrics found in results_df")
        print(f"Available columns: {results_df.columns.tolist()}")
        return

    emphasised = highlight_models if highlight_models is not None else []
    wide_rule = "=" * 80
    narrow_rule = "=" * 70

    # Dump the raw inputs so the plot can be cross-checked
    print("\n" + wide_rule)
    print("RADAR PLOT: RAW DATA FROM results_df")
    print(wide_rule)
    for name in usable:
        print(f"\n{name}:")
        print(results_df[name].to_string())
    print(wide_rule)

    # Min-max scale each metric; flip the lower-is-better ones
    scaled = pd.DataFrame(index=results_df.index)
    for name, spec in usable.items():
        column = results_df[name]
        low = column.min()
        high = column.max()

        if high == low:
            scaled[name] = 0.5
        elif spec['higher_is_better']:
            scaled[name] = (column - low) / (high - low)
        else:
            scaled[name] = (high - column) / (high - low)

    print("\n" + wide_rule)
    print("RADAR PLOT: NORMALIZED VALUES [0,1] (higher = better on plot)")
    print(wide_rule)
    print(scaled.to_string())
    print(wide_rule)

    # One spoke per metric; the first angle is repeated to close polygons
    spoke_labels = [usable[name]['label'] for name in scaled.columns]
    spokes = np.linspace(0, 2 * np.pi, len(spoke_labels), endpoint=False).tolist()
    loop_angles = spokes + spokes[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={'projection': 'polar'})

    palette = plt.cm.tab10(np.linspace(0, 1, len(scaled)))

    for row_number, model_name in enumerate(scaled.index):
        radii = scaled.loc[model_name].values.tolist()
        radii = radii + radii[:1]

        if any(key in model_name for key in emphasised):
            stroke, line_alpha, fill_alpha, dot_size = 3.5, 1.0, 0.25, 10
        else:
            stroke, line_alpha, fill_alpha, dot_size = 1.5, 0.6, 0.08, 6
        shade = palette[row_number]

        ax.plot(loop_angles, radii, 'o-', linewidth=stroke,
                label=model_name, alpha=line_alpha, color=shade, markersize=dot_size)
        ax.fill(loop_angles, radii, alpha=fill_alpha, color=shade)

    # Axis cosmetics
    ring_levels = [0.2, 0.4, 0.6, 0.8, 1.0]
    ax.set_xticks(loop_angles[:-1])
    ax.set_xticklabels(spoke_labels, size=12, fontweight='bold')
    ax.set_ylim(0, 1)
    ax.set_yticks(ring_levels)
    ax.set_yticklabels([f'{level:.1f}' for level in ring_levels], size=10, color='gray')
    ax.grid(True, linestyle='--', alpha=0.4, linewidth=1)
    ax.set_facecolor('#f8f9fa')

    ax.legend(loc='upper left', bbox_to_anchor=(1.15, 1.05),
              fontsize=11, frameon=True, ncol=1, shadow=True,
              fancybox=True, borderpad=1)

    ax.set_title('Model Performance Radar Chart\n'
                 '(All metrics normalized to [0,1], higher = better)',
                 size=15, fontweight='bold', pad=25)

    # Faint dotted reference rings at each radial tick
    for level in ring_levels:
        ax.plot(loop_angles, [level] * len(loop_angles), 'gray',
                linewidth=0.5, linestyle=':', alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Radar plot saved to: {save_path}")

    plt.show()

    # Summarise which way each metric points and its raw range
    print("\n" + narrow_rule)
    print("METRICS DIRECTION REFERENCE:")
    print(narrow_rule)
    for name, spec in usable.items():
        if spec['higher_is_better']:
            direction = "↑ Higher is better"
        else:
            direction = "↓ Lower is better (inverted for plot)"
        raw = results_df[name]
        print(f"  {name:<20} {direction}")
        print(f"    Range: [{raw.min():.4f}, {raw.max():.4f}]")
    print(narrow_rule)


def create_results_summary_table(results_df, highlight_models=None, save_path=None):
    """
    Create a formatted summary table highlighting best performances.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with results
    highlight_models : list, optional
        Models to highlight
    save_path : str, optional
        Path to save table as image

    Returns
    -------
    Styled DataFrame
    """
    if highlight_models is None:
        highlight_models = []

    def highlight_best(s):
        """Highlight best value in each column."""
        if 'ECE' in s.name or 'NLL' in s.name or 'Brier' in s.name:
            is_best = s == s.min()
        else:
            is_best = s == s.max()
        return ['background-color: lightgreen; font-weight: bold' if v else '' for v in is_best]

    def highlight_models_style(row):
        """Highlight specified models."""
        if any(h in row.name for h in highlight_models):
            return ['background-color: lightblue'] * len(row)
        return [''] * len(row)

    format_dict = {}
    for col in results_df.columns:
        if results_df[col].dtype in ['float64', 'float32', 'int64', 'int32']:
            format_dict[col] = "{:.4f}"

    styled = results_df.style\
        .apply(highlight_best, axis=0)\
        .apply(highlight_models_style, axis=1)\
        .format(format_dict)\
        .set_properties(**{'text-align': 'center'})\
        .set_table_styles([
            {'selector': 'th', 'props': [('font-weight', 'bold'), ('text-align', 'center')]},
            {'selector': 'td', 'props': [('padding', '8px')]}
        ])

    if save_path:
        print(f"Note: To save as image, use: styled.to_html() or screenshot")

    return styled


def plot_model_params(models, save_path=None, figsize=(12, 8)):
    """
    Create bar plot of parameter counts for all models.

    For an ensemble (a plain list of models or a ``DeepEnsemble``) the
    plotted total is the sum of ``count_params()`` over its members, so
    ensembles with differently sized members are counted correctly and an
    empty list contributes 0.

    Parameters
    ----------
    models : dict
        Dictionary with model_name -> model (or list of models / DeepEnsemble
        for an ensemble)
    save_path : str, optional
        Path to save figure
    figsize : tuple
        Figure size
    """
    names = []
    params = []

    for name, model in models.items():
        if isinstance(model, list):
            members = model
        elif isinstance(model, DeepEnsemble):
            members = model.members
        else:
            members = [model]
        names.append(name)
        params.append(sum(member.count_params() for member in members))

    fig, ax = plt.subplots(figsize=figsize)

    colors = ['#5B9BD5', '#ED7D31', '#70AD47', '#A5A5A5', '#FF6B6B', '#9B59B6']

    bars = ax.bar(range(len(names)), params, color=colors[:len(names)],
                  edgecolor='white', linewidth=1.5)

    # Print the exact count above each bar
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, fontsize=11, rotation=15, ha='right')
    ax.set_ylabel('Total Parameters', fontsize=13, fontweight='bold')
    ax.set_title('Total Parameters - All Models', fontsize=16, fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Thousands separators on the y axis
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Parameter plot saved to {save_path}")

    plt.show()


# =============================================================================
# RELIABILITY/CALIBRATION DIAGRAMS
# =============================================================================

def _bin_calibration_stats(preds, y_true, bin_edges):
    """Return per-bin (mean prediction, observed positive rate, count).

    Bins are half-open ``[lo, hi)`` except the last one, which is closed so a
    prediction exactly equal to ``bin_edges[-1]`` is still counted. Empty
    bins hold NaN in the first two arrays and 0 in the count array.
    """
    n_bins = len(bin_edges) - 1
    mean_pred = np.full(n_bins, np.nan)
    frac_true = np.full(n_bins, np.nan)
    counts = np.zeros(n_bins, dtype=int)

    for i, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        below_hi = (preds <= hi) if i == n_bins - 1 else (preds < hi)
        mask = (preds >= lo) & below_hi
        counts[i] = mask.sum()
        if counts[i] > 0:
            mean_pred[i] = preds[mask].mean()
            frac_true[i] = y_true[mask].mean()

    return mean_pred, frac_true, counts


def plot_reliability_diagrams(models_dict, X_test, y_test, n_bins=10,
                              n_samples=50, figsize=(15, 10), save_path=None):
    """
    Plot reliability diagrams (calibration plots) for each model.

    One subplot per model (three per row): the binned calibration curve
    against the diagonal, a histogram of the predictions on a secondary
    y axis (same bins as the curve), and the expected calibration error
    (ECE) in the title. Predictions equal to 1.0 fall into the last bin.

    Parameters
    ----------
    models_dict : dict
        Dictionary of trained models
    X_test : dict
        Test features
    y_test : np.ndarray
        Test labels
    n_bins : int
        Number of equal-width bins on [0, 1] for calibration
    n_samples : int
        Number of MC samples for Bayesian models
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure

    Returns
    -------
    None
        Returns immediately (after printing a notice) if ``models_dict`` is
        empty.
    """
    model_names = list(models_dict.keys())
    n_models = len(model_names)
    if n_models == 0:
        print("No models to plot")
        return

    ncols = 3
    nrows = int(np.ceil(n_models / ncols))

    # squeeze=False always yields a 2-D array, so flattening is safe even
    # for a single row / single model
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    bin_edges = np.linspace(0, 1, n_bins + 1)

    for idx, model_name in enumerate(model_names):
        model = models_dict[model_name]

        # Get predictions
        if isinstance(model, DeepEnsemble):
            preds, _, _, _ = ensemble_predictions_with_uncertainty(model, X_test)
        else:
            # NOTE(review): dropout is switched off before the MC sampling
            # call; confirm mc_predictions_with_mi handles stochasticity itself
            set_dropout_active(model, active=False)
            preds, _, _, _ = mc_predictions_with_mi(model, X_test, n_samples)

        # Flatten so an (N, 1) prediction array still indexes y_test row-wise
        preds = np.asarray(preds).ravel()

        mean_pred, frac_true, counts = _bin_calibration_stats(preds, y_test, bin_edges)
        valid_mask = counts > 0

        ax = axes[idx]

        # Plot calibration curve
        ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Perfect Calibration')
        ax.plot(mean_pred[valid_mask], frac_true[valid_mask],
                'o-', markersize=8, lw=2, color='#E63946', label='Model Calibration')

        # Histogram of predictions, using the same bins as the curve
        ax2 = ax.twinx()
        ax2.hist(preds, bins=bin_edges, alpha=0.3, color='gray', edgecolor='black')
        ax2.set_ylabel('Count', fontsize=9)
        ax2.tick_params(axis='y', labelsize=8)

        # ECE: count-weighted mean |accuracy - confidence| over non-empty bins
        total = counts.sum()
        if total > 0:
            gaps = np.abs(frac_true[valid_mask] - mean_pred[valid_mask])
            ece = np.sum(gaps * counts[valid_mask]) / total
        else:
            ece = np.nan

        ax.set_xlabel('Predicted Probability', fontsize=10)
        ax.set_ylabel('True Probability', fontsize=10)
        ax.set_title(f'{model_name}\nECE = {ece:.4f}', fontsize=11, fontweight='bold')
        ax.legend(loc='upper left', fontsize=8)
        ax.grid(alpha=0.3, linestyle='--')
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])

    # Hide unused subplots
    for idx in range(n_models, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Reliability diagrams saved to '{save_path}'")

    plt.show()


# =============================================================================
# ROC CURVES WITH CONFIDENCE INTERVALS
# =============================================================================

def plot_roc_curves_ood(models_dict, X_test, X_ood, n_samples=50,
                        n_bootstrap=10, figsize=(12, 8), save_path=None):
    """
    Plot ROC curves for OOD detection with a bootstrap uncertainty band.

    For each model the third value returned by the prediction helper (used
    here as the mutual-information score) is computed on the in-domain and
    OOD inputs, and in-domain samples are labelled 0, OOD samples 1. The
    pooled scores are resampled with replacement ``n_bootstrap`` times; the
    plotted curve is the mean interpolated TPR and the shaded band is
    +/- one standard deviation of the bootstrap TPRs, clipped to [0, 1].

    Resamples that contain only one class are skipped because the ROC curve
    is undefined for them. A model with no usable resample (for example
    ``n_bootstrap=0``) is skipped with a warning.

    Resampling uses the global NumPy random state, so seed ``np.random``
    beforehand for reproducible bands.

    Parameters
    ----------
    models_dict : dict
        Dictionary of trained models
    X_test : dict
        In-domain test features
    X_ood : dict
        Out-of-domain features
    n_samples : int
        Number of MC samples
    n_bootstrap : int
        Number of bootstrap iterations for the uncertainty band
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    """
    colors = {
        'Baseline Transformer': '#6C757D',
        'Full-Rank BBB': '#E63946',
        'Low-Rank Gaussian': '#2A9D8F',
        'Low-Rank Laplace': '#277DA1',
        'Rank-1 Multiplicative': '#577590',
        'Deep Ensemble': '#F4A261',
    }

    # Fallback colors for models not in the predefined list
    default_colors = plt.cm.tab10(np.linspace(0, 1, len(models_dict)))

    # Common FPR grid onto which every bootstrap curve is interpolated
    mean_fpr = np.linspace(0, 1, 100)

    fig, ax = plt.subplots(figsize=figsize)

    for idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"Computing ROC for {model_name}...")

        # Get predictions
        if isinstance(model, DeepEnsemble):
            _, _, mi_in, _ = ensemble_predictions_with_uncertainty(model, X_test)
            _, _, mi_ood, _ = ensemble_predictions_with_uncertainty(model, X_ood)
        else:
            # NOTE(review): dropout is switched off before the MC sampling
            # call; confirm mc_predictions_with_mi handles stochasticity itself
            set_dropout_active(model, active=False)
            _, _, mi_in, _ = mc_predictions_with_mi(model, X_test, n_samples)
            _, _, mi_ood, _ = mc_predictions_with_mi(model, X_ood, n_samples)

        # Create binary labels: 0 = in-domain, 1 = OOD
        y_true = np.concatenate([np.zeros(len(mi_in)), np.ones(len(mi_ood))])
        uncertainty_scores = np.concatenate([mi_in, mi_ood])

        # Bootstrap for the uncertainty band
        aucs = []
        tprs_interp = []

        for _ in range(n_bootstrap):
            indices = np.random.choice(len(y_true), size=len(y_true), replace=True)
            y_boot = y_true[indices]
            if np.unique(y_boot).size < 2:
                # Single-class resample: ROC is undefined and would inject NaNs
                continue
            scores_boot = uncertainty_scores[indices]

            fpr, tpr, _ = roc_curve(y_boot, scores_boot)
            aucs.append(auc(fpr, tpr))

            tpr_interp = np.interp(mean_fpr, fpr, tpr)
            tpr_interp[0] = 0.0  # every ROC curve starts at the origin
            tprs_interp.append(tpr_interp)

        if not aucs:
            print(f"Warning: No valid bootstrap resamples for {model_name}; skipping")
            continue

        # Plot mean ROC curve
        mean_tpr = np.mean(tprs_interp, axis=0)
        std_tpr = np.std(tprs_interp, axis=0)
        mean_auc = np.mean(aucs)
        std_auc = np.std(aucs)

        color = colors.get(model_name, default_colors[idx])
        ax.plot(mean_fpr, mean_tpr, color=color, lw=2.5,
                label=f'{model_name} (AUC = {mean_auc:.3f} +/- {std_auc:.3f})')

        # +/- 1 std band, clipped to the valid TPR range
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color=color, alpha=0.2)

    # Diagonal reference line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Random Classifier')

    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    ax.set_title('ROC Curves: OOD Detection using Mutual Information', fontsize=14, fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fontsize=10)
    ax.grid(alpha=0.3, linestyle='--')
    ax.set_xlim([-0.02, 1.02])
    ax.set_ylim([-0.02, 1.02])

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"ROC curves saved to '{save_path}'")

    plt.show()


# =============================================================================
# WEIGHT CORRELATION HEATMAPS
# =============================================================================

def plot_weight_correlation_heatmap(models_dict, layer_idx=0, figsize=(16, 6), save_path=None):
    """
    Show side-by-side correlation heatmaps of one Bayesian layer's mean weights.

    Left panel: the ``layer_idx``-th layer of ``models_dict['Full-Rank BBB']``
    that has a ``w_mu`` attribute. Right panel: the ``layer_idx``-th layer of
    ``models_dict['Low-Rank Gaussian']`` that has both ``A_mu`` and ``B_mu``;
    its weight matrix is reconstructed as ``A_mu @ B_mu.T``. Each panel shows
    ``np.corrcoef`` of the transposed weight matrix. A panel whose model or
    layer is missing displays a short message instead.

    Parameters
    ----------
    models_dict : dict
        Dictionary of trained models (should include 'Full-Rank BBB' and 'Low-Rank Gaussian')
    layer_idx : int
        Which Bayesian layer to visualize (0 = first Bayesian layer)
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    """
    def pick_layer(model, is_candidate):
        """Return the layer_idx-th layer accepted by is_candidate, else None."""
        candidates = (layer for layer in model.layers if is_candidate(layer))
        for ordinal, layer in enumerate(candidates):
            if ordinal == layer_idx:
                return layer
        return None

    def show_notice(ax, message):
        """Write a centred placeholder message into an otherwise empty panel."""
        ax.text(0.5, 0.5, message, ha='center', va='center')

    def draw_correlation(ax, weights, title):
        """Render the correlation matrix of weights.T with a colorbar."""
        image = ax.imshow(np.corrcoef(weights.T), cmap='RdBu_r', vmin=-1, vmax=1,
                          aspect='auto', interpolation='nearest')
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel('Output Neuron', fontsize=10)
        ax.set_ylabel('Output Neuron', fontsize=10)
        plt.colorbar(image, ax=ax, label='Correlation', fraction=0.046)

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    left_ax, right_ax = axes

    # Left panel: Full-Rank BBB
    if 'Full-Rank BBB' not in models_dict:
        show_notice(left_ax, 'Full-Rank BBB not in models_dict')
    else:
        dense_layer = pick_layer(models_dict['Full-Rank BBB'],
                                 lambda layer: hasattr(layer, 'w_mu'))
        if dense_layer is None:
            show_notice(left_ax, 'No Bayesian layer found')
        else:
            dense_weights = dense_layer.w_mu.numpy()
            print(f"Full-Rank: Found Bayesian layer with shape {dense_weights.shape}")
            draw_correlation(left_ax, dense_weights,
                             'Full-Rank BBB: Diagonal Correlation Structure')

    # Right panel: Low-Rank Gaussian
    if 'Low-Rank Gaussian' not in models_dict:
        show_notice(right_ax, 'Low-Rank Gaussian not in models_dict')
    else:
        factored_layer = pick_layer(
            models_dict['Low-Rank Gaussian'],
            lambda layer: hasattr(layer, 'A_mu') and hasattr(layer, 'B_mu'))
        if factored_layer is None:
            show_notice(right_ax, 'No low-rank layer found')
        else:
            factor_a = factored_layer.A_mu.numpy()
            factor_b = factored_layer.B_mu.numpy()
            product = factor_a @ factor_b.T
            print(f"Low-Rank: Found Bayesian layer with shape {product.shape} (rank={factor_a.shape[1]})")
            draw_correlation(right_ax, product, 'Low-Rank: Block Correlation Structure')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Weight correlation heatmap saved to '{save_path}'")

    plt.show()


# =============================================================================
# RANK ANALYSIS PLOTS
# =============================================================================

def plot_rank_heatmaps(df, metrics=None, output_dir='figures'):
    """
    Create heatmaps showing optimal rank combinations for each metric.

    The figure has one row per metric and one column per distinct
    ``kl_scale_parsed`` value (a single column over all rows if that column
    is absent). Each cell is the mean of the metric for an (r1, r2) pair;
    the best pair is named in the subplot title. Metrics whose name contains
    'NLL', 'ECE' or 'Brier' are treated as lower-is-better, all others as
    higher-is-better. The figure is always written to
    ``{output_dir}/rank_heatmaps.png``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with rank search results (columns: r1, r2, kl_scale_parsed, metrics)
    metrics : list, optional
        Metrics to plot; entries that are not columns of ``df`` are dropped
    output_dir : str
        Output directory for saving figures (created if missing)

    Returns
    -------
    None
        Returns early (after printing a notice) when none of the requested
        metrics is a column of ``df``.
    """
    if metrics is None:
        metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']

    # Filter to available metrics
    metrics = [m for m in metrics if m in df.columns]
    if not metrics:
        print("No valid metrics found in DataFrame")
        return

    metric_labels = {
        'NLL': 'NLL (lower is better)',
        'ECE': 'ECE (lower is better)',
        'AUROC_OOD_MI': 'AUROC OOD (higher is better)',
        'AUPR_OOD_MI': 'AUPR OOD (higher is better)'
    }
    lower_is_better_tags = ('NLL', 'ECE', 'Brier')

    if 'kl_scale_parsed' not in df.columns:
        print("Warning: kl_scale_parsed column not found, creating aggregated plot only")
        kl_scales = [None]
    else:
        kl_scales = sorted(df['kl_scale_parsed'].unique())

    n_rows = len(metrics)
    n_cols = max(len(kl_scales), 1)

    # squeeze=False keeps axes 2-D for every rows/cols combination, so
    # axes[i, j] is valid even with one metric or one KL scale
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(5 * n_cols, 4 * n_rows), squeeze=False)

    for i, metric in enumerate(metrics):
        lower_is_better = any(tag in metric for tag in lower_is_better_tags)

        for j, kl_scale in enumerate(kl_scales):
            ax = axes[i, j]

            if kl_scale is not None:
                df_kl = df[df['kl_scale_parsed'] == kl_scale]
            else:
                df_kl = df

            pivot = df_kl.pivot_table(index='r1', columns='r2', values=metric, aggfunc='mean')

            stacked = pivot.stack()
            if stacked.empty:
                # Nothing to rank (e.g. the metric is all-NaN for this slice)
                ax.text(0.5, 0.5, 'No data', ha='center', va='center')
                continue

            if lower_is_better:
                cmap = 'RdYlGn_r'
                best_idx = stacked.idxmin()
                best_val = stacked.min()
            else:
                cmap = 'RdYlGn'
                best_idx = stacked.idxmax()
                best_val = stacked.max()

            sns.heatmap(pivot, ax=ax, cmap=cmap, annot=True, fmt='.3f',
                        cbar_kws={'label': metric})

            title = f'{metric_labels.get(metric, metric)}'
            if kl_scale is not None:
                title += f'\nKL Scale={kl_scale}'
            title += f'\nBest: r1={best_idx[0]}, r2={best_idx[1]} ({best_val:.3f})'
            ax.set_title(title)
            ax.set_xlabel('Rank 2 (Layer 2)')
            ax.set_ylabel('Rank 1 (Layer 1)')

    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    save_path = f'{output_dir}/rank_heatmaps.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")


def plot_pareto_front(df, output_dir='figures'):
    """
    Scatter calibration metrics against OOD detection for every rank setting.

    Two side-by-side panels are drawn: ``NLL`` vs ``AUROC_OOD_MI`` and ``ECE``
    vs ``AUROC_OOD_MI``. Every point is annotated with its ``(r1,r2)`` pair.
    When ``df`` has a ``kl_scale_parsed`` column, one labelled scatter series
    is drawn per distinct value and a legend is added; otherwise all rows are
    drawn as a single unlabelled series.

    Parameters
    ----------
    df : pd.DataFrame
        Rank search results; the columns ``NLL``, ``ECE``, ``AUROC_OOD_MI``,
        ``r1`` and ``r2`` are read, plus ``kl_scale_parsed`` when present.
    output_dir : str
        Directory (created if missing) that receives
        ``rank_pareto_tradeoff.png``.
    """
    _, panels = plt.subplots(1, 2, figsize=(14, 6))

    split_by_kl = 'kl_scale_parsed' in df.columns

    def draw_series(panel, frame, x_col, **series_kwargs):
        # One scatter series plus an "(r1,r2)" text tag next to each point.
        panel.scatter(frame[x_col], frame['AUROC_OOD_MI'],
                      alpha=0.7, s=100, **series_kwargs)
        for _, record in frame.iterrows():
            panel.annotate(f'({int(record["r1"])},{int(record["r2"])})',
                           (record[x_col], record['AUROC_OOD_MI']),
                           fontsize=7, alpha=0.7)

    # (x column, x-axis label, panel title) for the left and right panels.
    panel_specs = (
        ('NLL', 'NLL (lower is better)', 'NLL vs OOD Detection Trade-off'),
        ('ECE', 'ECE (lower is better)', 'ECE vs OOD Detection Trade-off'),
    )

    for panel, (x_col, x_label, heading) in zip(panels, panel_specs):
        if split_by_kl:
            for scale in sorted(df['kl_scale_parsed'].unique()):
                draw_series(panel, df[df['kl_scale_parsed'] == scale], x_col,
                            label=f'KL={scale}')
        else:
            draw_series(panel, df, x_col)

        panel.set_xlabel(x_label, fontsize=11)
        panel.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        panel.set_title(heading, fontsize=12, fontweight='bold')
        # A legend only makes sense when the series carry KL labels.
        if split_by_kl:
            panel.legend()
        panel.grid(alpha=0.3)

    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    save_path = f'{output_dir}/rank_pareto_tradeoff.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")


# =============================================================================
# MODULE INFO
# =============================================================================

if __name__ == "__main__":
    # Printed only when the file is executed directly: a hand-maintained
    # catalogue of the plotting helpers, grouped by topic. Each tuple entry is
    # emitted with its own print() call, so a leading "\n" yields a blank line.
    _catalog_lines = (
        "Visualization module loaded successfully!",
        "\nAvailable functions:",
        "\n  Training & Convergence:",
        "  - plot_icml_convergence: Plot validation accuracy trajectories",
        "  - plot_training_loss: Plot training loss trajectories",
        "  - plot_ensemble_training: Plot ensemble member training curves",
        "\n  Model Comparison:",
        "  - plot_model_comparison: Comprehensive model comparison charts",
        "  - plot_metrics_comparison: Grouped bar plot comparing metrics",
        "  - plot_metrics_radar: Radar/spider plot for model comparison",
        "  - plot_uncertainty_comparison: Uncertainty-focused comparison",
        "  - create_results_summary_table: Styled summary table",
        "  - plot_model_params: Bar plot of parameter counts",
        "  - create_results_heatmap: Heatmap of evaluation metrics",
        "\n  Uncertainty & Calibration:",
        "  - plot_uncertainty_distributions: ID vs OOD uncertainty histograms",
        "  - plot_calibration_diagram: Reliability diagram",
        "  - plot_reliability_diagrams: Multiple model reliability diagrams",
        "  - plot_prediction_histograms: Prediction probability distributions",
        "\n  ROC & OOD Detection:",
        "  - plot_roc_curves_ood: ROC curves with confidence intervals",
        "\n  Singular Value & Rank Analysis:",
        "  - analyze_singular_values: SVD analysis for ALL layers (can be crowded)",
        "  - analyze_singular_values_selected: SVD for SELECTED layers (cleaner, recommended)",
        "  - plot_rank_heatmaps: Heatmaps for rank grid search",
        "  - plot_pareto_front: Trade-off visualization",
        "\n  Weight Analysis:",
        "  - plot_weight_correlation_heatmap: Weight correlation heatmaps",
        "\n  MC Convergence Analysis:",
        "  - analyze_mc_convergence: Analyze metric convergence with MC samples",
    )
    for _catalog_line in _catalog_lines:
        print(_catalog_line)


# =============================================================================
# MC CONVERGENCE ANALYSIS
# =============================================================================

def analyze_mc_convergence(models_dict, X_test, y_test, X_ood, y_ood,
                           mc_samples_range=None, metrics_to_plot=None,
                           save_path=None):
    """
    Analyze how metrics converge with increasing MC_samples.

    Every model is re-evaluated once per entry of ``mc_samples_range`` via
    ``modules.evaluation.evaluate_all_models_with_optimized_ece``. The
    requested metrics are then plotted against the sample count (one subplot
    per metric) and a per-model stability summary is printed. The intent is to:
    1. Verify Deep Ensemble stability (should be flat - no MC dependence)
    2. Show Bayesian model convergence (should stabilize with more samples)
    3. Identify optimal MC sample count for each model

    Models whose name contains ``'Ensemble'`` are drawn dashed with square
    markers; the first such model is additionally flagged STABLE/UNSTABLE on
    each subplot depending on whether its variance is below 1e-10.

    Parameters
    ----------
    models_dict : dict
        Dictionary of models to analyze {model_name: model}
    X_test, y_test : dict, array
        In-distribution test data
    X_ood, y_ood : dict, array
        Out-of-distribution test data
    mc_samples_range : list, optional
        List of MC sample counts to test (default: [10, 20, 50, 100, 200, 300, 500])
    metrics_to_plot : list, optional
        List of metrics to plot (default: ['Accuracy', 'AUROC_OOD_STD', 'AUROC_OOD_MI', 'NLL']).
        Any number of metrics is supported; the subplot grid grows in rows of two.
    save_path : str, optional
        Path to save the plot

    Returns
    -------
    dict
        Dictionary with convergence data: {model_name: {metric: [values]}}.
        Each list has one entry per element of ``mc_samples_range``; a metric
        that an evaluation did not report is recorded as ``nan`` so the list
        stays aligned with the sample counts.

    Example
    -------
    >>> convergence_data = analyze_mc_convergence(
    ...     models_dict={'Low-Rank BBB': lowrank_model, 'Deep Ensemble': ensemble},
    ...     X_test=X_test_dict,
    ...     y_test=y_test_arr,
    ...     X_ood=X_ood_dict,
    ...     y_ood=y_ood_arr,
    ...     mc_samples_range=[10, 50, 100, 200, 500],
    ...     save_path='mc_convergence.png'
    ... )
    """
    from modules.evaluation import evaluate_all_models_with_optimized_ece

    if mc_samples_range is None:
        mc_samples_range = [10, 20, 50, 100, 200, 300, 500]

    if metrics_to_plot is None:
        metrics_to_plot = ['Accuracy', 'AUROC_OOD_STD', 'AUROC_OOD_MI', 'NLL']

    print("="*80)
    print("MC SAMPLES CONVERGENCE ANALYSIS")
    print("="*80)
    print(f"Testing MC samples: {mc_samples_range}")
    print(f"Metrics: {metrics_to_plot}")
    print("="*80)

    # Store results
    convergence_data = {
        model_name: {metric: [] for metric in metrics_to_plot}
        for model_name in models_dict
    }

    # Evaluate each model at different MC sample counts
    for n_samples in mc_samples_range:
        print(f"\n{'='*80}")
        print(f"Testing with n_samples = {n_samples}")
        print(f"{'='*80}")

        for model_name, model in models_dict.items():
            print(f"\nEvaluating {model_name}...")

            metrics = evaluate_all_models_with_optimized_ece(
                model=model,
                X_test=X_test,
                y_test=y_test,
                X_ood=X_ood,
                y_ood=y_ood,
                model_name=model_name,
                n_samples=n_samples
            )

            # Store metrics. A metric the evaluation did not report is padded
            # with NaN so every series keeps one entry per sample count;
            # silently skipping it would desynchronise x and y when plotting.
            for metric in metrics_to_plot:
                if metric in metrics:
                    value = metrics[metric]
                    print(f"  {metric}: {value:.6f}")
                else:
                    value = np.nan
                    print(f"  {metric}: not reported (recorded as NaN)")
                convergence_data[model_name][metric].append(value)

    # Plot convergence: two subplots per row, as many rows as needed.
    # With the default four metrics this is the same 2x2, 16x12 inch layout.
    n_metrics = len(metrics_to_plot)
    n_cols = 2
    n_rows = max(1, (n_metrics + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(8 * n_cols, 6 * n_rows), squeeze=False)
    axes = axes.flatten()

    colors = plt.cm.tab10(np.linspace(0, 1, len(models_dict)))

    # Only the first ensemble-like model gets the stability annotation.
    ensemble_name = next(
        (name for name in convergence_data if 'Ensemble' in name), None
    )

    for ax, metric in zip(axes, metrics_to_plot):
        for color, (model_name, data) in zip(colors, convergence_data.items()):
            values = data[metric]

            # Check if model is Deep Ensemble (should be flat)
            is_ensemble = 'Ensemble' in model_name
            linestyle = '--' if is_ensemble else '-'
            linewidth = 3 if is_ensemble else 2
            marker = 's' if is_ensemble else 'o'

            ax.plot(mc_samples_range, values,
                    marker=marker, linestyle=linestyle, linewidth=linewidth,
                    color=color, label=model_name,
                    markersize=8, alpha=0.8)

        ax.set_xlabel('Number of MC Samples', fontsize=12, fontweight='bold')
        ax.set_ylabel(metric, fontsize=12, fontweight='bold')
        ax.set_title(f'{metric} Convergence', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=9, loc='best')

        # Add annotation for expected behavior
        if ensemble_name is not None:
            ensemble_values = np.asarray(
                convergence_data[ensemble_name][metric], dtype=float
            )
            ensemble_values = ensemble_values[~np.isnan(ensemble_values)]
            # Without any reported value there is nothing to judge, so the
            # annotation is omitted rather than mislabelled as unstable.
            if ensemble_values.size:
                if np.var(ensemble_values) < 1e-10:
                    ax.text(0.05, 0.95, '✓ Deep Ensemble: STABLE (as expected)',
                            transform=ax.transAxes, fontsize=9,
                            verticalalignment='top', color='green',
                            bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))
                else:
                    ax.text(0.05, 0.95, '✗ Deep Ensemble: UNSTABLE (BUG!)',
                            transform=ax.transAxes, fontsize=9,
                            verticalalignment='top', color='red',
                            bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.3))

    # Hide grid cells that have no metric assigned (odd metric counts).
    for spare_ax in axes[n_metrics:]:
        spare_ax.set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n✓ Plot saved to {save_path}")

    plt.show()

    # Print summary statistics
    print("\n" + "="*80)
    print("CONVERGENCE SUMMARY")
    print("="*80)

    for model_name, data in convergence_data.items():
        print(f"\n{model_name}:")
        for metric in metrics_to_plot:
            values = np.asarray(data[metric], dtype=float)
            values = values[~np.isnan(values)]
            if values.size == 0:
                # min()/max() on an empty array would raise.
                print(f"  {metric:20s}: no data")
                continue

            min_val = values.min()
            max_val = values.max()
            std_val = values.std()
            mean_val = values.mean()
            range_val = max_val - min_val
            # Coefficient of variation in percent; defined as 0 for a zero mean.
            cv = (std_val / np.abs(mean_val)) * 100 if mean_val != 0 else 0

            # Check stability
            is_stable = cv < 0.01  # Less than 0.01% CV
            status = "✓ STABLE" if is_stable else "⚠ VARYING"

            print(f"  {metric:20s}: min={min_val:.6f}, max={max_val:.6f}, "
                  f"range={range_val:.6f}, CV={cv:.4f}%  {status}")

    return convergence_data
def plot_pareto_front(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create a Pareto front plot showing trade-off between calibration (NLL/ECE)
    and OOD performance (AUROC_OOD).

    This definition has the same name as an earlier ``plot_pareto_front`` in
    this module; being defined later, it is the one bound at import time. To
    keep callers of either variant working, both column schemas are accepted.

    The left panel (NLL) also shows dashed guide lines at the 25th percentile
    of NLL and the 75th percentile of AUROC.

    Parameters
    ----------
    df : pd.DataFrame
        Rank search results. For each quantity the first column that exists is
        used: ``nll``/``NLL``, ``ece``/``ECE``, ``auroc``/``AUROC_OOD_MI``,
        ``R1``/``r1`` and ``R2``/``r2``. If a ``kl_scale_parsed`` column is
        present, one labelled series is drawn per distinct value; otherwise
        all rows form a single series.
    output_dir : str
        Directory (created if missing) that receives
        ``rank_pareto_tradeoff.png``.

    Raises
    ------
    KeyError
        If none of the accepted names for a required column is in ``df``.
    """
    def pick_column(*candidates):
        # First match wins, so the original lowercase schema keeps priority.
        for name in candidates:
            if name in df.columns:
                return name
        raise KeyError(f"None of the columns {candidates} found in DataFrame")

    # Resolve every column up front so a bad frame fails before a figure exists.
    nll_col = pick_column('nll', 'NLL')
    ece_col = pick_column('ece', 'ECE')
    auroc_col = pick_column('auroc', 'AUROC_OOD_MI')
    r1_col = pick_column('R1', 'r1')
    r2_col = pick_column('R2', 'r2')
    has_kl = 'kl_scale_parsed' in df.columns

    def draw_series(ax, frame, x_col, **series_kwargs):
        # One scatter series; each point is tagged with its rank pair.
        ax.scatter(frame[x_col], frame[auroc_col],
                   alpha=0.7, s=100, **series_kwargs)
        for _, row in frame.iterrows():
            ax.annotate(f'({int(row[r1_col])},{int(row[r2_col])})',
                        (row[x_col], row[auroc_col]),
                        fontsize=7, alpha=0.7)

    def draw_panel(ax, x_col, metric_name):
        if has_kl:
            for kl_scale in sorted(df['kl_scale_parsed'].unique()):
                draw_series(ax, df[df['kl_scale_parsed'] == kl_scale], x_col,
                            label=f'KL={kl_scale}')
        else:
            draw_series(ax, df, x_col)

        ax.set_xlabel(f'{metric_name} (lower is better)', fontsize=11)
        ax.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        ax.set_title(f'{metric_name} vs OOD Detection Trade-off',
                     fontsize=12, fontweight='bold')
        ax.grid(alpha=0.3)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: NLL vs AUROC_OOD
    ax1 = axes[0]
    draw_panel(ax1, nll_col, 'NLL')

    # Highlight ideal region (upper-left: low NLL, high AUROC)
    ax1.axvline(df[nll_col].quantile(0.25), color='green', linestyle='--',
                alpha=0.3, label='25th percentile NLL')
    ax1.axhline(df[auroc_col].quantile(0.75), color='green', linestyle='--',
                alpha=0.3)
    # The legend is built after the guide line so its label is included.
    ax1.legend()

    # Plot 2: ECE vs AUROC_OOD
    ax2 = axes[1]
    draw_panel(ax2, ece_col, 'ECE')
    # Without KL grouping this panel has no labelled artists to list.
    if has_kl:
        ax2.legend()

    plt.tight_layout()

    # Honour output_dir (previously ignored in favour of a fixed 'figures/').
    os.makedirs(output_dir, exist_ok=True)
    save_path = f'{output_dir}/rank_pareto_tradeoff.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")


def plot_pareto_front_performance(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create Pareto front plots showing trade-off between predictive performance
    (MAE/RMSE) and OOD performance (AUROC_OOD).

    Two panels are drawn (``mae`` vs ``auroc`` and ``rmse`` vs ``auroc``), with
    one scatter series per distinct ``kl_scale_parsed`` value. Each point is
    annotated with its ``(R1,R2)`` pair, and dashed guide lines mark the 25th
    percentile of the error metric and the 75th percentile of ``auroc``.

    Parameters
    ----------
    df : pd.DataFrame
        Results table; the columns ``mae``, ``rmse``, ``auroc``, ``R1``,
        ``R2`` and ``kl_scale_parsed`` are read.
    output_dir : str
        Directory (created if missing) that receives
        ``rank_pareto_performance.png``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    kl_scales = sorted(df['kl_scale_parsed'].unique())
    auroc_guide = df['auroc'].quantile(0.75)

    # (error column, display name) for the left and right panels.
    for ax, (error_col, error_name) in zip(axes, (('mae', 'MAE'), ('rmse', 'RMSE'))):
        for kl_scale in kl_scales:
            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            ax.scatter(df_kl[error_col], df_kl['auroc'],
                       label=f'KL={kl_scale}', alpha=0.7, s=100)

            # Annotate points with rank values
            for _, row in df_kl.iterrows():
                ax.annotate(f'({int(row["R1"])},{int(row["R2"])})',
                            (row[error_col], row['auroc']),
                            fontsize=7, alpha=0.7)

        ax.set_xlabel(f'{error_name} (lower is better)', fontsize=11)
        ax.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        ax.set_title(f'{error_name} vs OOD Detection Trade-off',
                     fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)
        # Guide lines bound the desirable corner: low error, high AUROC.
        ax.axvline(df[error_col].quantile(0.25), color='green', linestyle='--', alpha=0.3)
        ax.axhline(auroc_guide, color='green', linestyle='--', alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    save_path = f'{output_dir}/rank_pareto_performance.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")

