"""
Visualization Functions for Bayesian LSTM

This module contains all plotting-related functions:
- Multi-seed robustness box plots
- Training history plots
- Comparison bar plots
- Calibration curves
- Sharpness curves
- Selective prediction plots
- Radar plots
- Singular value analysis plots
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re


# ==============================================================================
# Multi-Seed Robustness Plots
# ==============================================================================

def plot_multi_seed_boxplots(aggregated_results,
                              model_names,
                              seeds,
                              metrics=None,
                              figsize_per_metric=(4, 5),
                              palette='Set2',
                              save_path=None):
    """
    Create box plots with overlaid points showing metric distributions across seeds.

    A metric gets its own subplot when at least one of ``model_names`` reports it.
    Every selected model is checked, not just one. Otherwise metrics that a
    single model lacks (e.g. OOD scores of a deterministic baseline) would be
    silently dropped for all models.

    Args:
        aggregated_results: Dictionary with structure:
                           {model_name: {metric_name: {'mean': float, 'std': float, 'values': list}}}
        model_names: List of model names to include
        seeds: List of seed values used (paired positionally with each 'values' list)
        metrics: List of metrics to plot. Default: ['MAE', 'NLL', 'ECE', 'PICP', 'AUROC_OOD']
        figsize_per_metric: Tuple (width, height) per metric subplot
        palette: Seaborn color palette name
        save_path: Path to save figure (optional)

    Returns:
        fig, axes: Matplotlib figure and axes objects, or (None, None) when none of
        the requested metrics is reported by any selected model.
    """
    import pandas as pd

    # Default metrics
    if metrics is None:
        metrics = ['MAE', 'NLL', 'ECE', 'PICP', 'AUROC_OOD']

    results = aggregated_results or {}

    # Keep a metric if ANY selected model reports it.
    available_metrics = [
        metric for metric in metrics
        if any(metric in results.get(name, {}) for name in model_names)
    ]

    if not available_metrics:
        print("No metrics available for visualization")
        return None, None

    # Only touch the global seaborn style once we know something will be drawn.
    sns.set_style("whitegrid")

    n_metrics = len(available_metrics)
    fig, axes = plt.subplots(1, n_metrics,
                             figsize=(figsize_per_metric[0] * n_metrics, figsize_per_metric[1]))
    if n_metrics == 1:
        axes = [axes]

    for ax, metric in zip(axes, available_metrics):
        label = metric.replace('_', ' ')

        # Long-form rows: one per (model, seed) run that has a valid value.
        plot_data = []
        for model_name in model_names:
            stats = results.get(model_name, {}).get(metric)
            if stats is None:
                continue
            for seed_idx, value in enumerate(stats['values']):
                # Skip failed / missing runs so they don't break the box plot.
                if value is None or np.isnan(value):
                    continue
                plot_data.append({
                    'Model': model_name.replace(' ', '\n'),
                    'Value': value,
                    # Fall back to the run index if fewer seeds than values were given.
                    'Seed': seeds[seed_idx] if seed_idx < len(seeds) else seed_idx,
                })

        if not plot_data:
            # Keep the panel labelled instead of leaving an anonymous blank axis.
            ax.set_title(label, fontsize=12, fontweight='bold')
            ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
            continue

        df_plot = pd.DataFrame(plot_data)

        # Box plot with individual seed results overlaid as points
        sns.boxplot(data=df_plot, x='Model', y='Value', ax=ax, palette=palette)
        sns.stripplot(data=df_plot, x='Model', y='Value', ax=ax,
                      color='black', alpha=0.6, size=6, jitter=True)

        ax.set_title(label, fontsize=12, fontweight='bold')
        ax.set_xlabel('')
        ax.set_ylabel(label, fontsize=10)
        ax.tick_params(axis='x', rotation=45, labelsize=8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved to: {save_path}")

    plt.show()
    return fig, axes


# ==============================================================================
# Training History Plots
# ==============================================================================

def plot_training_history(history, title="Training History", save_path=None):
    """
    Plot training and validation loss curves, plus MAE curves when recorded.

    A validation curve is drawn only when its ``val_*`` key exists, so histories
    from runs fitted without validation data still plot. When no ``mae`` metric
    was recorded, the right-hand panel is hidden instead of being left blank.

    Args:
        history: Keras History object (its ``history`` dict is read)
        title: Plot title
        save_path: Optional path to save the figure
    """
    logs = history.history
    _, axes = plt.subplots(1, 2, figsize=(14, 4))

    def _draw_panel(ax, key, name, ylabel):
        # Draw the train curve and, if logged, the matching validation curve.
        ax.plot(logs[key], label=f'Train {name}', linewidth=2)
        val_key = f'val_{key}'
        if val_key in logs:
            ax.plot(logs[val_key], label=f'Val {name}', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(f'{title} - {name}', fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Loss
    _draw_panel(axes[0], 'loss', 'Loss', 'MSE Loss')
    # MAE (optional metric)
    if 'mae' in logs:
        _draw_panel(axes[1], 'mae', 'MAE', 'MAE')
    else:
        axes[1].axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_ensemble_training_curves(ensemble_histories, M, save_path=None):
    """
    Plot Deep Ensemble validation curves (loss and MAE), one line per member.

    A member whose history lacks a key (e.g. no ``val_mae`` because MAE was not a
    compiled metric) is skipped for that panel. A panel with no curves at all is
    hidden rather than drawn empty with a legend warning.

    Args:
        ensemble_histories: List of Keras History objects (their ``history`` dicts are read)
        M: Number of ensemble members (used in the panel titles)
        save_path: Optional path to save the figure
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    # (axis, history key, y-axis label, title suffix)
    panels = (
        (axes[0], 'val_loss', 'Validation Loss (MSE)', 'Val Loss'),
        (axes[1], 'val_mae', 'Validation MAE', 'Val MAE'),
    )
    for ax, key, ylabel, suffix in panels:
        n_drawn = 0
        for member_idx, history in enumerate(ensemble_histories):
            curve = history.history.get(key)
            if curve is None:
                continue
            ax.plot(curve, label=f'Member {member_idx+1}', alpha=0.7)
            n_drawn += 1
        if n_drawn == 0:
            ax.axis('off')
            continue
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(f'Deep Ensemble ({M} Members) - {suffix}', fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ==============================================================================
# Comparison Bar Plots
# ==============================================================================

def plot_point_prediction_comparison(results, model_names, colors, save_path=None):
    """
    Plot bar charts comparing point prediction metrics (MAE, RMSE, R2) across models.

    Args:
        results: Dictionary with results per model; ``results[m]['test'][metric]``
                 is read for metric in 'MAE', 'RMSE', 'R2'
        model_names: List of model names to compare
        colors: List of colors for each model
        save_path: Optional path to save the figure
    """
    # (metric key, y-axis label, panel title, value-label format)
    panels = [
        ('MAE', 'MAE (ug/m3)', 'Mean Absolute Error (lower is better)', '{:.2f}'),
        ('RMSE', 'RMSE (ug/m3)', 'Root Mean Squared Error (lower is better)', '{:.2f}'),
        ('R2', 'R2 Score', 'Coefficient of Determination (higher is better)', '{:.3f}'),
    ]
    positions = range(len(model_names))
    tick_labels = [m.replace(' ', '\n') for m in model_names]

    # One row with one panel per metric (a 2x3 grid would leave a row empty).
    _, axes = plt.subplots(1, len(panels), figsize=(16, 5))

    for ax, (metric, ylabel, title, fmt) in zip(axes, panels):
        vals = [results[m]['test'][metric] for m in model_names]
        bars = ax.bar(positions, vals, color=colors, alpha=0.7)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, fontsize=9)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        if metric == 'R2':
            # R2 can be negative on the test set; a fixed [0, 1] range would
            # hide such bars entirely, so extend the lower limit when needed.
            finite = [v for v in vals if np.isfinite(v)]
            ax.set_ylim([min([0.0] + [v - 0.05 for v in finite]), 1])
        for bar, val in zip(bars, vals):
            # Label above positive bars, below the tip of negative ones.
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), fmt.format(val),
                    ha='center', va='bottom' if val >= 0 else 'top', fontsize=10)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_uncertainty_comparison(uncertainty_results, model_names, colors, confidence_level=0.95,
                                save_path=None):
    """
    Plot bar charts comparing uncertainty metrics across models.

    Panels: NLL, CRPS, ECE, PICP (with the nominal coverage as a red target line),
    MPIW and mean predicted std. PICP is best when it matches ``confidence_level``.
    Coverage above the target usually just means overly wide intervals, so that
    panel is labelled "closer to target is better".

    Args:
        uncertainty_results: Dictionary with uncertainty results per model; keys
                             'NLL', 'CRPS', 'ECE', 'PICP', 'MPIW', 'mean_std' are read
        model_names: List of model names to compare
        colors: List of colors for each model
        confidence_level: Confidence level for PICP target line
        save_path: Optional path to save the figure
    """
    # (metric key, y-axis label, panel title, value-label format)
    panels = [
        ('NLL', 'NLL', 'Negative Log-Likelihood (lower is better)', '{:.3f}'),
        ('CRPS', 'CRPS', 'Continuous Ranked Probability Score (lower is better)', '{:.3f}'),
        ('ECE', 'ECE', 'Expected Calibration Error (lower is better)', '{:.4f}'),
        ('PICP', 'PICP',
         f'Prediction Interval Coverage ({confidence_level*100:.0f}% PI)\n(closer to target is better)',
         '{:.3f}'),
        ('MPIW', 'MPIW (ug/m3)', 'Mean Prediction Interval Width (lower is better)', '{:.2f}'),
        ('mean_std', 'Mean Std (ug/m3)', 'Average Predicted Uncertainty', '{:.2f}'),
    ]
    positions = range(len(model_names))
    tick_labels = [m.replace(' ', '\n') for m in model_names]

    _, axes = plt.subplots(2, 3, figsize=(16, 10))

    for ax, (metric, ylabel, title, fmt) in zip(axes.flatten(), panels):
        vals = [uncertainty_results[m][metric] for m in model_names]
        bars = ax.bar(positions, vals, color=colors, alpha=0.7)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, fontsize=9)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        if metric == 'PICP':
            ax.axhline(y=confidence_level, color='red', linestyle='--', linewidth=2, label='Target')
            ax.set_ylim([0, 1])
            ax.legend(loc='lower right')
        for bar, val in zip(bars, vals):
            # NLL can be negative: place those labels below the bar tip.
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), fmt.format(val),
                    ha='center', va='bottom' if val >= 0 else 'top', fontsize=10)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ==============================================================================
# Calibration Curves
# ==============================================================================

def plot_calibration_curves(uncertainty_results, y_test_original, model_names, colors, save_path=None):
    """
    Plot calibration curves (expected vs. observed interval coverage) for all models.

    Side effects: for each model, writes into ``uncertainty_results[model_name]``:
    - ``'calibration_curve'``: dict with 'expected' and 'observed' arrays.
    - ``'calibration_auc'``: the trapezoidal area of |observed - expected| over
      expected coverage.
    It also prints that area.

    Args:
        uncertainty_results: Dictionary with uncertainty results per model
                             ('pred_mean' and 'pred_std' are read)
        y_test_original: True values in original scale
        model_names: List of model names
        colors: List of colors
        save_path: Optional path to save the figure
    """
    from modules.evaluation import compute_calibration_curve

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
    # whichever exists so this works on both NumPy 1.x and 2.x.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    n_models = len(model_names)
    fig, axes = plt.subplots(1, n_models, figsize=(4.5 * n_models, 5))
    if n_models == 1:
        axes = [axes]

    for idx, model_name in enumerate(model_names):
        # Get predictions
        y_pred_mean = uncertainty_results[model_name]['pred_mean']
        y_pred_std = uncertainty_results[model_name]['pred_std']
        # Compute calibration curve
        expected_cov, observed_cov = compute_calibration_curve(
            y_test_original, y_pred_mean, y_pred_std, num_bins=15
        )
        # Coerce to float arrays so the element-wise difference below also
        # works if the helper returns plain lists.
        expected_cov = np.asarray(expected_cov, dtype=float)
        observed_cov = np.asarray(observed_cov, dtype=float)
        # Store for later
        uncertainty_results[model_name]['calibration_curve'] = {
            'expected': expected_cov,
            'observed': observed_cov,
        }
        # Plot
        ax = axes[idx]
        ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Perfect Calibration')
        ax.plot(expected_cov, observed_cov, 'o-', color=colors[idx],
                linewidth=2, markersize=6, label=model_name)
        ax.fill_between([0, 1], [0, 1], alpha=0.1, color='gray')
        ax.set_xlabel('Expected Coverage', fontsize=12)
        ax.set_ylabel('Observed Coverage', fontsize=12)
        ax.set_title(f'{model_name}', fontsize=13, fontweight='bold')
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_aspect('equal')
        # Compute calibration AUC (area between curves)
        calib_auc = trapezoid(np.abs(observed_cov - expected_cov), expected_cov)
        print(f"  {model_name} Calibration Error (AUC): {calib_auc:.4f}")
        uncertainty_results[model_name]['calibration_auc'] = calib_auc

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_sharpness_curves(uncertainty_results, model_names, colors, save_path=None):
    """
    Plot sharpness curves: a histogram of predicted standard deviations per model.

    Each panel also marks the mean (red dashed line) and the median (blue dashed
    line) of that model's predicted std.

    Args:
        uncertainty_results: Dictionary with uncertainty results per model
                             ('pred_std' is read)
        model_names: List of model names
        colors: List of colors (indexed in the same order as model_names)
        save_path: Optional path to save the figure
    """
    n_models = len(model_names)
    fig, axes = plt.subplots(1, n_models, figsize=(4.5 * n_models, 5))
    axes = [axes] if n_models == 1 else axes

    for i, name in enumerate(model_names):
        stds = uncertainty_results[name]['pred_std']
        ax = axes[i]
        mean_std = np.mean(stds)
        median_std = np.median(stds)

        ax.hist(stds, bins=30, color=colors[i], alpha=0.7, edgecolor='black')
        ax.axvline(mean_std, color='red', linestyle='--',
                   linewidth=2, label=f'Mean: {mean_std:.2f}')
        ax.axvline(median_std, color='blue', linestyle='--',
                   linewidth=2, label=f'Median: {median_std:.2f}')

        ax.set_xlabel('Predicted Std (ug/m3)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title(f'{name}', fontsize=13, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ==============================================================================
# Selective Prediction Plots
# ==============================================================================

def plot_selective_prediction(selective_results, model_names, colors, save_path=None):
    """
    Plot selective prediction curves: each metric as a function of retention rate.

    One panel per metric (MAE, RMSE, R2, mean predicted std) with one line per
    model. The x-axis is inverted so higher retention is on the left, and a grey
    dashed line marks 80% retention. The last two cells of the 2x3 grid are blank.

    Args:
        selective_results: Dictionary mapping model name -> list of dicts, each with
                           a 'retention' fraction and the metric keys above
        model_names: List of model names
        colors: List of colors (indexed in the same order as model_names)
        save_path: Optional path to save the figure
    """
    labels_by_metric = {
        'MAE': 'MAE (ug/m3) (lower is better)',
        'RMSE': 'RMSE (ug/m3) (lower is better)',
        'R2': 'R2 Score (higher is better)',
        'mean_std': 'Mean Std (ug/m3)',
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    flat_axes = axes.flatten()

    for ax, (metric, label) in zip(flat_axes, labels_by_metric.items()):
        for color_idx, name in enumerate(model_names):
            curve = selective_results[name]
            x_pct = [point['retention'] * 100 for point in curve]
            y_vals = [point[metric] for point in curve]
            ax.plot(x_pct, y_vals, 'o-', color=colors[color_idx],
                    linewidth=2, markersize=6, label=name)

        ax.set_xlabel('Retention Rate (%)', fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(f'{label}', fontsize=13, fontweight='bold')
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        ax.invert_xaxis()  # higher retention on the left
        ax.axvline(80, color='gray', linestyle='--', alpha=0.5)  # 80% retention marker

    # Blank out the grid cells that have no metric assigned.
    for spare_ax in flat_axes[len(labels_by_metric):]:
        spare_ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ==============================================================================
# Radar Plots
# ==============================================================================

def _normalize_metrics(values, higher_is_better):
    """
    Min-max normalize metric values to [0, 1] for radar plots (1 = best).

    NaN entries are ignored when computing the range and stay NaN in the output.
    As a result, one missing value no longer turns the whole metric into NaN.

    Args:
        values: Sequence of metric values (one per model).
        higher_is_better: If True the largest value maps to 1, otherwise the smallest does.

    Returns:
        np.ndarray of floats with the same length as ``values``:
        - 0.5 for every valid entry when all valid values are (nearly) equal.
        - All NaN when no entry is valid.
        - An empty array for empty input.
    """
    values = np.asarray(values, dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        # Also covers empty input; avoids nanmin/nanmax raising or warning.
        return np.full_like(values, np.nan)
    vmin, vmax = np.nanmin(values), np.nanmax(values)
    if np.isclose(vmax, vmin):
        return np.where(valid, 0.5, np.nan)
    span = vmax - vmin
    if higher_is_better:
        return (values - vmin) / span
    return (vmax - values) / span


def _plot_radar(ax, labels, data, title, colors):
    """
    Draw a closed radar (spider) chart on a polar axis.

    Spokes start at 12 o'clock and run clockwise, with the radial range fixed to [0, 1].
    Non-finite values (e.g. NaN for a metric a model does not report) are drawn
    at 0. Otherwise a single NaN would break the outline and corrupt the filled
    polygon.

    Args:
        ax: Matplotlib polar axes to draw on.
        labels: One label per spoke.
        data: Mapping of series name -> sequence of values, one per label.
        title: Axis title.
        colors: Mapping of series name -> color; missing names use the default cycle.
    """
    num_vars = len(labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close the polygon

    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_ylim(0, 1)
    ax.set_title(title, y=1.1, fontsize=12, fontweight='bold')

    for name, values in data.items():
        vals = [float(v) if np.isfinite(v) else 0.0 for v in values]
        vals = vals + vals[:1]  # close the loop
        color = colors.get(name, None)
        ax.plot(angles, vals, linewidth=2, label=name, color=color)
        ax.fill(angles, vals, alpha=0.1, color=color)


def plot_radar_comparison(results, uncertainty_results, model_names, uq_model_names, colors_dict, save_path=None):
    """
    Plot radar charts for predictive performance and calibration/UQ.

    Each metric is min-max normalized across the compared models (outer = better).

    The Calib_AUC spoke reads ``uncertainty_results[m]['calibration_auc']``, which
    ``plot_calibration_curves`` stores. If any UQ model lacks that key, the spoke
    is omitted. Substituting 0 would rank that model as perfectly calibrated.

    Args:
        results: Dictionary with point prediction results per model
        uncertainty_results: Dictionary with uncertainty results per model
        model_names: List of all model names (for performance radar)
        uq_model_names: List of UQ model names (for UQ radar)
        colors_dict: Dictionary mapping model names to colors
        save_path: Optional path to save the figure
    """
    # Radar 1: Predictive performance
    perf_metrics = ['MAE', 'RMSE', 'R2']
    perf_higher_is_better = {'MAE': False, 'RMSE': False, 'R2': True}

    perf_values = {m: [] for m in model_names}
    for metric in perf_metrics:
        vals = [results[m]['test'][metric] for m in model_names]
        norm = _normalize_metrics(vals, perf_higher_is_better[metric])
        for m, v in zip(model_names, norm):
            perf_values[m].append(v)

    # Radar 2: Calibration / UQ
    uq_metrics = ['NLL', 'CRPS', 'ECE', 'PICP', 'MPIW', 'Calib_AUC']
    uq_higher_is_better = {'NLL': False, 'CRPS': False, 'ECE': False,
                           'PICP': True, 'MPIW': False, 'Calib_AUC': False}

    if not all('calibration_auc' in uncertainty_results[m] for m in uq_model_names):
        print("Note: 'calibration_auc' missing for some models; omitting Calib_AUC from the "
              "UQ radar (run plot_calibration_curves first to include it).")
        uq_metrics.remove('Calib_AUC')

    uq_values = {m: [] for m in uq_model_names}
    for metric in uq_metrics:
        key = 'calibration_auc' if metric == 'Calib_AUC' else metric
        vals = [uncertainty_results[m][key] for m in uq_model_names]
        norm = _normalize_metrics(vals, uq_higher_is_better[metric])
        for m, v in zip(uq_model_names, norm):
            uq_values[m].append(v)

    plt.figure(figsize=(14, 6))
    ax1 = plt.subplot(1, 2, 1, polar=True)
    _plot_radar(ax1, perf_metrics, perf_values, 'Predictive Performance (Normalized)', colors_dict)
    ax1.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

    ax2 = plt.subplot(1, 2, 2, polar=True)
    _plot_radar(ax2, uq_metrics, uq_values, 'Calibration / UQ (Normalized)', colors_dict)
    ax2.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ==============================================================================
# Comprehensive Radar Plots & Singular Value Analysis Plots
# ==============================================================================

def plot_comprehensive_radar(results, uncertainty_results, ood_df, uq_model_names, colors_dict, save_path=None):
    """
    Plot a comprehensive radar chart combining predictive performance, calibration, UQ, and OOD metrics.

    Args:
        results: Dictionary with point prediction results per model
        uncertainty_results: Dictionary with uncertainty results per model
        ood_df: DataFrame with OOD detection metrics; its 'Model', 'AUROC' and
                'AUPR' columns are read
        uq_model_names: List of UQ model names to compare
        colors_dict: Dictionary mapping model names to colors
        save_path: Optional path to save the figure

    Raises:
        KeyError: if a model in ``uq_model_names`` has no row in ``ood_df``.

    Metrics included (each min-max normalized across models, outer = better):
        - R2 (higher is better)
        - Calib_AUC (lower is better). Omitted when any model lacks
          'calibration_auc' (stored by plot_calibration_curves), because
          substituting 0 would rank that model as perfectly calibrated.
        - ECE (lower is better)
        - PICP (higher is better)
        - NLL (lower is better)
        - AUROC_OOD (higher is better)
        - AUPR_OOD (higher is better)
    """
    # metric -> (higher_is_better, spoke label); dict order fixes spoke order.
    spec = {
        'R2': (True, 'R²'),
        'Calib_AUC': (False, 'Calib AUC'),
        'ECE': (False, 'ECE'),
        'PICP': (True, 'PICP'),
        'NLL': (False, 'NLL'),
        'AUROC_OOD': (True, 'AUROC\n(OOD)'),
        'AUPR_OOD': (True, 'AUPR\n(OOD)'),
    }
    ood_columns = {'AUROC_OOD': 'AUROC', 'AUPR_OOD': 'AUPR'}

    # Build OOD lookup from dataframe (later duplicate rows win)
    ood_lookup = {row['Model']: {'AUROC': row['AUROC'], 'AUPR': row['AUPR']}
                  for _, row in ood_df.iterrows()}
    missing_ood = [m for m in uq_model_names if m not in ood_lookup]
    if missing_ood:
        raise KeyError(f"No OOD metrics in ood_df for model(s): {missing_ood}")

    metrics = list(spec)
    if not all('calibration_auc' in uncertainty_results[m] for m in uq_model_names):
        print("Note: 'calibration_auc' missing for some models; omitting Calib AUC from the "
              "radar (run plot_calibration_curves first to include it).")
        metrics.remove('Calib_AUC')

    def _raw_value(model_name, metric):
        # Look up the un-normalized value of `metric` for one model.
        if metric == 'R2':
            return results[model_name]['test']['R2']
        if metric == 'Calib_AUC':
            return uncertainty_results[model_name]['calibration_auc']
        if metric in ood_columns:
            return ood_lookup[model_name][ood_columns[metric]]
        return uncertainty_results[model_name][metric]

    # Normalize each metric across models
    normalized_values = {m: [] for m in uq_model_names}
    for metric in metrics:
        raw = [_raw_value(name, metric) for name in uq_model_names]
        norm = _normalize_metrics(raw, spec[metric][0])
        for name, value in zip(uq_model_names, norm):
            normalized_values[name].append(value)

    # Create radar plot
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, polar=True)
    labels = [spec[m][1] for m in metrics]
    _plot_radar(ax, labels, normalized_values,
                'Comprehensive Model Comparison\n(outer = better)', colors_dict)
    ax.legend(loc='upper right', bbox_to_anchor=(1.35, 1.1), fontsize=10)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_radar_from_dataframe(df, metrics=None, save_path=None):
    """
    Plot a comprehensive radar chart from a multi-seed summary DataFrame.

    For each requested metric, the ``<metric>_mean`` column is min-max normalized
    across models, ignoring NaNs. Each row of ``df`` is then drawn as one solid
    outline. Known model names get fixed colors; others use the default color cycle.

    Args:
        df: DataFrame with a 'Model' column plus '<metric>_mean' columns
            (e.g. 'R2_mean', 'ECE_mean'), as produced by a multi-seed summary CSV.
        metrics: Metric names (without the _mean suffix) to include. Defaults to
                 ['R2', 'ECE', 'PICP', 'NLL', 'CRPS', 'CWC', 'AUROC_OOD', 'AUPR_OOD'].
                 Metrics without a matching column are skipped; if none match,
                 diagnostics are printed and nothing is plotted.
        save_path: Optional path to save the figure

    Note:
        - NaN values (e.g. UQ metrics of a deterministic model) are drawn at 0.
          A model whose values are all NaN is not drawn.
        - CWC (Coverage Width-based Criterion) combines PICP and MPIW, which is
          why MPIW is not in the default set.
        - Metrics missing from the direction table are treated as lower-is-better.
        - All metrics are normalized to [0, 1] where outer = better.
    """
    if metrics is None:
        metrics = ['R2', 'ECE', 'PICP', 'NLL', 'CRPS', 'CWC', 'AUROC_OOD', 'AUPR_OOD']

    # True = higher is better, False = lower is better
    direction = {
        'R2': True, 'ECE': False, 'PICP': True, 'NLL': False, 'CRPS': False,
        'CWC': False, 'AUROC_OOD': True, 'AUPR_OOD': True, 'FPR95_OOD': False,
        'MAE': False, 'RMSE': False,
    }
    # Spoke labels that differ from the metric name (others are shown as-is).
    pretty = {
        'R2': 'R²',
        'AUROC_OOD': 'AUROC\n(OOD)',
        'AUPR_OOD': 'AUPR\n(OOD)',
        'FPR95_OOD': 'FPR@95',
    }
    palette = {
        'Deterministic': '#1f77b4',
        'Full-Rank Bayesian': '#ff7f0e',
        'Low-Rank Bayesian': '#2ca02c',
        'Low-Rank (SVD Init)': '#d62728',
        'Rank-1 Bayesian': '#9467bd',
        'Deep Ensemble': '#8c564b',
    }

    present = [name for name in metrics if f'{name}_mean' in df.columns]
    if len(present) == 0:
        print("No valid metrics found in DataFrame!")
        print(f"Expected columns like: {[f'{m}_mean' for m in metrics]}")
        print(f"Found columns: {df.columns.tolist()}")
        return

    def _scale(column, higher_better):
        # NaN-aware min-max scaling; NaNs stay NaN, a flat column maps to 0.5.
        arr = np.array(column, dtype=float)
        ok = ~np.isnan(arr)
        if not ok.any():
            return np.full_like(arr, np.nan)
        lo, hi = np.nanmin(arr), np.nanmax(arr)
        if np.isclose(hi, lo):
            return np.where(ok, 0.5, np.nan)
        if higher_better:
            return (arr - lo) / (hi - lo)
        return (hi - arr) / (hi - lo)

    model_list = df['Model'].tolist()
    per_model = {name: [] for name in model_list}
    spoke_labels = []
    for metric in present:
        scaled = _scale(df[f'{metric}_mean'].values, direction.get(metric, False))
        for pos, name in enumerate(model_list):
            per_model[name].append(scaled[pos])
        spoke_labels.append(pretty.get(metric, metric))

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, polar=True)

    theta = np.linspace(0, 2 * np.pi, len(spoke_labels), endpoint=False).tolist()
    theta.append(theta[0])  # close the loop

    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(theta[:-1]), spoke_labels, fontsize=10)
    ax.set_ylim(0, 1)

    for name in model_list:
        series = per_model[name]
        if all(np.isnan(v) for v in series):
            continue  # nothing meaningful to draw for this model
        closed = [0 if np.isnan(v) else v for v in series]
        closed += closed[:1]
        color = palette.get(name, None)
        ax.plot(theta, closed, linewidth=2.5, label=name, color=color)
        ax.fill(theta, closed, alpha=0.1, color=color)

    ax.set_title('Multi-Seed Model Comparison\n(outer = better)',
                 fontsize=14, fontweight='bold', y=1.08)
    ax.legend(loc='upper right', bbox_to_anchor=(1.35, 1.1), fontsize=10)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_singular_value_analysis(model, target_energy=70.0, figsize=(16, 8), save_path=None):
    """
    Analyze and plot singular value decay for x_to_gates and h_to_gates layers.

    Rule of thumb: pick the minimum rank where the overall architecture
    (average across LSTM layers) reaches target_energy (%).

    Every layer in ``model.layers`` whose ``name`` ends with ``_x_to_gates`` or
    ``_h_to_gates`` and whose first weight array is 2-D is analysed. For each
    such kernel the cumulative energy (cumulative sum of squared singular
    values, as a percentage of the total) is computed. Kernels named
    ``layer<k>_...`` are first averaged per LSTM layer index ``k``. Gate layers
    without that prefix each form their own group, so they still count towards
    the overall average. The proposed rank is the smallest ``r`` at which the
    mean over groups reaches ``target_energy``. Candidate ranks are capped at
    the smallest number of singular values among the analysed kernels.

    Args:
        model: Object exposing ``layers``. Each layer is expected to provide
            ``name`` and ``get_weights()``.
        target_energy: Energy threshold in percent.
        figsize: Matplotlib figure size.
        save_path: Optional path to save the figure.

    Returns:
        The proposed rank (int), or None when no gate layers are found.
    """
    # Collect the 2-D kernels of the gate projection layers.
    gate_layers = []
    for layer in model.layers:
        name = getattr(layer, "name", "")
        if not (name.endswith("_x_to_gates") or name.endswith("_h_to_gates")):
            continue
        weights = layer.get_weights()
        if weights and weights[0].ndim == 2:
            gate_layers.append((name, weights[0]))
    if not gate_layers:
        print("No gate layers found.")
        return None

    # A single SVD per kernel, reused for both the rank search and the plots.
    singular_values = {}
    layer_energy = {}
    for name, W in gate_layers:
        s = np.linalg.svd(W, full_matrices=False, compute_uv=False)
        squared = s ** 2
        total_energy = np.sum(squared)
        singular_values[name] = s
        if total_energy > 0:
            layer_energy[name] = np.cumsum(squared) / total_energy * 100
        else:
            # An all-zero kernel has nothing to approximate, so every rank is
            # exact. Dividing by zero would instead yield NaN and poison the average.
            layer_energy[name] = np.full(len(s), 100.0)

    # Group x/h gate kernels by LSTM layer index. Unprefixed names get their own group.
    groups = {}
    for name in layer_energy:
        m = re.match(r"layer(\d+)_", name)
        key = int(m.group(1)) if m else name
        groups.setdefault(key, []).append(name)

    max_rank = min(len(e) for e in layer_energy.values())

    # Per-group energy curve (mean of its kernels), then the mean over groups.
    group_curves = [np.mean([layer_energy[n][:max_rank] for n in names], axis=0)
                    for names in groups.values()]
    overall_curve = np.mean(group_curves, axis=0)
    reached = np.flatnonzero(overall_curve >= target_energy)
    # Fall back to the largest admissible rank when the target is never reached.
    proposed_rank = int(reached[0]) + 1 if reached.size else max_rank

    n_layers = len(gate_layers)
    fig, axes = plt.subplots(2, n_layers, figsize=figsize, squeeze=False)

    for col_idx, (name, W) in enumerate(gate_layers):
        cumulative_energy = layer_energy[name]
        s_vals = singular_values[name]
        r_idx = min(proposed_rank, len(cumulative_energy)) - 1

        # Top row: singular value spectrum (log scale).
        ax1 = axes[0, col_idx]
        ax1.plot(range(1, len(s_vals) + 1), s_vals, 'o-', color='#264653', markersize=3, lw=1.5)
        ax1.axvline(x=proposed_rank, color='#E63946', linestyle='--', lw=2, label=f'r={proposed_rank}')
        ax1.set_title(f"{name}: {W.shape}", fontsize=10, fontweight='bold')
        ax1.set_xlabel('Index (i)', fontsize=9, fontweight='bold')
        ax1.set_ylabel('sigma_i(W*)', fontsize=9, fontweight='bold')
        ax1.legend(loc='upper right', fontsize=8)
        ax1.grid(alpha=0.3, linestyle='--')
        ax1.set_yscale('log')

        # Bottom row: cumulative energy with the proposed rank marked.
        ax2 = axes[1, col_idx]
        ax2.plot(range(1, len(cumulative_energy) + 1), cumulative_energy, 'o-', color='#2A9D8F', markersize=3, lw=1.5)
        ax2.axvline(x=proposed_rank, color='#E63946', linestyle='--', lw=2, label=f'{cumulative_energy[r_idx]:.1f}%')
        ax2.axhline(y=cumulative_energy[r_idx], color='#E63946', linestyle=':', lw=1.5, alpha=0.7)
        ax2.set_xlabel('Rank', fontsize=9, fontweight='bold')
        ax2.set_ylabel('Energy (%)', fontsize=9, fontweight='bold')
        ax2.set_title(f"r={proposed_rank}: {cumulative_energy[r_idx]:.1f}%", fontsize=10, fontweight='bold')
        ax2.legend(loc='lower right', fontsize=8)
        ax2.grid(alpha=0.3, linestyle='--')
        ax2.set_ylim([0, 105])

    fig.suptitle(f"Singular Value Decay (Target {target_energy:.0f}%)", fontsize=12, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print('\n' + '=' * 80)
    print(f"Proposed rank (overall avg >= {target_energy:.0f}%): r = {proposed_rank}")
    print('=' * 80)
    return proposed_rank
import pandas as pd
def plot_pareto_front(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create a Pareto front plot showing trade-off between calibration (NLL/ECE)
    and OOD performance (AUROC_OOD).

    Each panel draws one scatter series per unique ``kl_scale_parsed`` value.
    Every point is annotated with its ``(R1, R2)`` rank pair. Dashed green guide
    lines mark the 25th percentile of the x-metric and the 75th percentile of
    AUROC, bounding the region of best trade-offs.

    Args:
        df: DataFrame with columns 'kl_scale_parsed', 'nll', 'ece', 'auroc',
            'R1' and 'R2'.
        output_dir: Directory the figure is saved into as
            'rank_pareto_tradeoff.png'. Created if it does not exist.
    """
    import os

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    kl_scales = sorted(df['kl_scale_parsed'].unique())

    def _draw_panel(ax, x_col, metric_name):
        # Scatter x_col vs AUROC per KL scale, with rank annotations and guide lines.
        for kl_scale in kl_scales:
            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            ax.scatter(df_kl[x_col], df_kl['auroc'],
                       label=f'KL={kl_scale}', alpha=0.7, s=100)
            for _, row in df_kl.iterrows():
                ax.annotate(f'({int(row["R1"])},{int(row["R2"])})',
                            (row[x_col], row['auroc']),
                            fontsize=7, alpha=0.7)

        # Guide lines are drawn before legend() so the labelled one is listed.
        ax.axvline(df[x_col].quantile(0.25), color='green', linestyle='--',
                   alpha=0.3, label=f'25th percentile {metric_name}')
        ax.axhline(df['auroc'].quantile(0.75), color='green', linestyle='--', alpha=0.3)

        ax.set_xlabel(f'{metric_name} (lower is better)', fontsize=11)
        ax.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        ax.set_title(f'{metric_name} vs OOD Detection Trade-off', fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

    _draw_panel(axes[0], 'nll', 'NLL')
    _draw_panel(axes[1], 'ece', 'ECE')

    plt.tight_layout()
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Honour output_dir (previously ignored in favour of a hard-coded 'figures/').
    plt.savefig(os.path.join(output_dir, 'rank_pareto_tradeoff.png'), dpi=300, bbox_inches='tight')
    plt.show()


def plot_pareto_front_performance(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create Pareto front plots showing trade-off between predictive performance
    (MAE/RMSE) and OOD performance (AUROC_OOD).

    Each panel draws one scatter series per unique ``kl_scale_parsed`` value.
    Every point is annotated with its ``(R1, R2)`` rank pair. Dashed green guide
    lines mark the 25th percentile of the x-metric and the 75th percentile of AUROC.

    Args:
        df: DataFrame with columns 'kl_scale_parsed', 'mae', 'rmse', 'auroc',
            'R1' and 'R2'.
        output_dir: Directory the figure is saved into as
            'rank_pareto_performance.png'. Created if it does not exist.
    """
    import os

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    kl_scales = sorted(df['kl_scale_parsed'].unique())

    def _draw_panel(ax, x_col, metric_name):
        # Scatter x_col vs AUROC per KL scale, with rank annotations and guide lines.
        for kl_scale in kl_scales:
            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            ax.scatter(df_kl[x_col], df_kl['auroc'],
                       label=f'KL={kl_scale}', alpha=0.7, s=100)
            for _, row in df_kl.iterrows():
                ax.annotate(f'({int(row["R1"])},{int(row["R2"])})',
                            (row[x_col], row['auroc']),
                            fontsize=7, alpha=0.7)

        # Guide lines are drawn before legend() so the labelled one is listed.
        ax.axvline(df[x_col].quantile(0.25), color='green', linestyle='--',
                   alpha=0.3, label=f'25th percentile {metric_name}')
        ax.axhline(df['auroc'].quantile(0.75), color='green', linestyle='--', alpha=0.3)

        ax.set_xlabel(f'{metric_name} (lower is better)', fontsize=11)
        ax.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        ax.set_title(f'{metric_name} vs OOD Detection Trade-off', fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

    _draw_panel(axes[0], 'mae', 'MAE')
    _draw_panel(axes[1], 'rmse', 'RMSE')

    plt.tight_layout()
    # savefig fails with FileNotFoundError if the target directory is missing.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'rank_pareto_performance.png'), dpi=300, bbox_inches='tight')
    plt.show()


def plot_uncertainty_distributions(models_dict, X_id, X_ood, n_samples=50, save_path=None):
    """
    Plot uncertainty distributions for in-distribution vs OOD data.

    Each model gets one row with two panels of overlaid density histograms
    (ID in blue, OOD in red). The left panel shows the predictive standard
    deviation and the right panel shows the mutual information (MI). MI values
    are clipped to [0, 99th percentile of the pooled ID+OOD MI] for display.

    Parameters
    ----------
    models_dict : dict
        Mapping of display name -> model. ``DeepEnsemble`` instances are
        evaluated with ``ensemble_predictions_with_uncertainty``. All other
        models use ``mc_predictions_with_mi`` with ``n_samples`` MC samples.
    X_id :
        In-distribution inputs, passed unchanged to the prediction helpers.
    X_ood :
        Out-of-distribution inputs, passed unchanged to the prediction helpers.
    n_samples : int
        Number of MC samples for non-ensemble (Bayesian) models.
    save_path : str, optional
        Path to save the figure.

    Returns
    -------
    None. If ``models_dict`` is empty, a message is printed and nothing is plotted.
    """
    n_models = len(models_dict)
    if n_models == 0:
        # plt.subplots(0, 2) would raise, so bail out early instead.
        print("No models to plot.")
        return None

    # squeeze=False keeps axes 2-D even for a single model.
    fig, axes = plt.subplots(n_models, 2, figsize=(12, 3 * n_models), squeeze=False)

    colors_id = '#2E86AB'  # Blue for ID
    colors_ood = '#C73E1D'  # Red for OOD

    for idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"Computing uncertainty for {model_name}...")

        # Both helpers return a 4-tuple whose last two entries are (MI, std).
        if isinstance(model, DeepEnsemble):
            _, _, mi_id, std_id = ensemble_predictions_with_uncertainty(model, X_id)
            _, _, mi_ood, std_ood = ensemble_predictions_with_uncertainty(model, X_ood)
        else:
            _, _, mi_id, std_id = mc_predictions_with_mi(model, X_id, n_samples)
            _, _, mi_ood, std_ood = mc_predictions_with_mi(model, X_ood, n_samples)

        # Plot STD distribution
        ax = axes[idx, 0]
        ax.hist(std_id, bins=50, alpha=0.6, color=colors_id,
                label='In-Distribution', density=True, edgecolor='white')
        ax.hist(std_ood, bins=50, alpha=0.6, color=colors_ood,
                label='OOD', density=True, edgecolor='white')
        ax.set_xlabel('Predictive Std', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.set_title(f'{model_name}: Std Distribution', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

        # Plot MI distribution. Clip for visualization, because a long tail
        # would otherwise squash the histograms. The bound is computed once.
        ax = axes[idx, 1]
        mi_upper = np.percentile(np.concatenate([mi_id, mi_ood]), 99)
        mi_id_clipped = np.clip(mi_id, 0, mi_upper)
        mi_ood_clipped = np.clip(mi_ood, 0, mi_upper)
        ax.hist(mi_id_clipped, bins=50, alpha=0.6, color=colors_id,
                label='In-Distribution', density=True, edgecolor='white')
        ax.hist(mi_ood_clipped, bins=50, alpha=0.6, color=colors_ood,
                label='OOD', density=True, edgecolor='white')
        ax.set_xlabel('Mutual Information', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.set_title(f'{model_name}: MI Distribution', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    fig.suptitle('Uncertainty Distributions: In-Distribution vs OOD',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()

def plot_metrics_radar(results_df,
                       metrics_to_plot=None,
                       highlight_models=None,
                       figsize=(12, 12),
                       save_path=None):
    """
    Create radar/spider plot to compare model performance across metrics.

    Each metric is min-max normalized to [0, 1] across models. Metrics in the
    lower-is-better list are inverted so that the outer ring is always better.
    A metric whose values are all equal (or all NaN) is drawn at 0.5. Remaining
    NaNs are drawn at 0. Requested metrics that are missing from
    ``results_df`` are skipped.

    Args:
        results_df: DataFrame with models as rows, metrics as columns.
                    Can be summary_df format with 'Model' column and 'metric_mean' columns,
                    or simple format with model names as index and metric names as columns.
        metrics_to_plot: List of metrics to plot (without _mean suffix if using summary_df).
                         Default: ['RMSE', 'NLL', 'ECE', 'PICP', 'MPIW', 'AUROC_OOD', 'AUPR_OOD']
        highlight_models: Models to highlight with thicker lines (substring match on
                          the model name).
        figsize: Figure size
        save_path: Path to save figure

    Returns:
        None. If no requested metric or no model row is available, a message is
        printed and nothing is plotted.
    """
    if metrics_to_plot is None:
        metrics_to_plot = ['RMSE', 'NLL', 'ECE', 'PICP', 'MPIW', 'AUROC_OOD', 'AUPR_OOD']

    if highlight_models is None:
        highlight_models = []

    if 'Model' in results_df.columns:
        # summary_df format: pick '<metric>_mean' columns, indexed by model name.
        df_indexed = results_df.set_index('Model')
        mean_cols = {metric: df_indexed[f'{metric}_mean']
                     for metric in metrics_to_plot
                     if f'{metric}_mean' in df_indexed.columns}
        df_work = pd.DataFrame(mean_cols)
    else:
        # Simple format. Skip absent metrics, as the summary path does, instead of raising KeyError.
        present = [metric for metric in metrics_to_plot if metric in results_df.columns]
        df_work = results_df[present].copy()

    # Drop rows with all NaN (e.g., Deterministic model for UQ metrics)
    df_work = df_work.dropna(how='all')
    if df_work.empty:
        print("No metrics available for radar plot")
        return None

    # Metrics where lower is better (need to invert for radar: outer = better)
    lower_is_better = ['RMSE', 'MAE', 'NLL', 'ECE', 'MPIW', 'CRPS', 'CWC', 'FPR95_OOD']

    # Normalize metrics to [0, 1] for radar plot (higher = better after normalization)
    df_normalized = df_work.copy()
    for col in df_normalized.columns:
        values = df_normalized[col].values
        valid_mask = ~np.isnan(values)

        if not valid_mask.any():
            df_normalized[col] = 0.5
            continue

        min_val = np.nanmin(values)
        max_val = np.nanmax(values)

        # Avoid division by zero if all values are the same
        if np.isclose(max_val, min_val):
            df_normalized[col] = 0.5
        elif col in lower_is_better:
            # Lowest value -> 1.0, highest value -> 0.0
            df_normalized[col] = (max_val - values) / (max_val - min_val)
        else:
            # Higher is better (PICP, AUROC, AUPR, R2, etc.)
            df_normalized[col] = (values - min_val) / (max_val - min_val)

        # Replace NaN with 0 for plotting
        df_normalized[col] = np.where(np.isnan(df_normalized[col].values), 0, df_normalized[col].values)

    # Setup radar plot
    categories = [m.replace('_', '\n') for m in df_normalized.columns]
    N = len(categories)

    angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
    angles += angles[:1]  # Complete the circle

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(projection='polar'))

    palette = plt.cm.tab10
    for idx, model_name in enumerate(df_normalized.index):
        # Positional access stays correct even if model names repeat.
        values = df_normalized.iloc[idx].to_numpy().tolist()
        values += values[:1]  # Complete the circle

        is_highlighted = any(highlight in str(model_name) for highlight in highlight_models)
        linewidth, alpha = (3, 0.9) if is_highlighted else (2, 0.7)
        # tab10 has 10 entries. Integer lookups past the end all return the
        # same "over" colour, so wrap the index to keep models distinguishable.
        color = palette(idx % palette.N)

        ax.plot(angles, values, 'o-', linewidth=linewidth,
                label=model_name, alpha=alpha, color=color)
        ax.fill(angles, values, alpha=0.15, color=color)

    # Formatting
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, size=11)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], size=9)
    ax.grid(True, linestyle='--', alpha=0.5)

    # Push legend outside to avoid overlap on crowded plots
    ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1.0),
              fontsize=10, frameon=False, ncol=1)
    ax.set_title('Model Performance Comparison\n(All metrics normalized to [0,1], outer=better)',
                 size=14, fontweight='bold', pad=20)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()


def find_optimal_ranks(df: pd.DataFrame) -> pd.DataFrame:
    """
    Find the optimal rank combination for each metric and KL scale.

    For each of NLL, ECE (minimized) and AUROC_OOD_MI, AUPR_OOD_MI (maximized)
    present in ``df``, the metric is averaged per ``(R1, R2)`` pair. The best
    pair is then reported once over all rows ('All (averaged)') and once per
    unique ``kl_scale_parsed`` value, in sorted order. Averaging is used in
    both cases, so repeated runs (e.g. several seeds) of the same
    configuration are compared by their mean rather than by a single lucky run.
    Metrics missing from ``df`` are skipped. Groups with no non-NaN value
    produce no row.

    Args:
        df: DataFrame with columns 'R1', 'R2', 'kl_scale_parsed' and some of
            the metric columns listed above.

    Returns:
        DataFrame with columns 'Metric', 'KL_Scale', 'Optimal_r1',
        'Optimal_r2', 'Value'.
    """
    metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']
    minimized = {'NLL', 'ECE'}
    columns = ['Metric', 'KL_Scale', 'Optimal_r1', 'Optimal_r2', 'Value']

    def _best_ranks(frame, metric):
        # Mean metric per (R1, R2). Returns the best row, or None if nothing is valid.
        agg = frame.groupby(['R1', 'R2'])[metric].mean().dropna().reset_index()
        if agg.empty:
            return None
        scores = agg[metric]
        best_label = scores.idxmin() if metric in minimized else scores.idxmax()
        return agg.loc[best_label]

    kl_scales = sorted(df['kl_scale_parsed'].unique())
    results = []
    for metric in metrics:
        if metric not in df.columns:
            continue

        scopes = [('All (averaged)', df)]
        scopes += [(kl_scale, df[df['kl_scale_parsed'] == kl_scale]) for kl_scale in kl_scales]

        for kl_label, frame in scopes:
            best_row = _best_ranks(frame, metric)
            if best_row is None:
                continue
            results.append({
                'Metric': metric,
                'KL_Scale': kl_label,
                'Optimal_r1': int(best_row['R1']),
                'Optimal_r2': int(best_row['R2']),
                'Value': float(best_row[metric]),
            })

    return pd.DataFrame(results, columns=columns)





