"""
Visualization Functions for Model Comparison

This module provides comprehensive visualization tools for comparing
Bayesian neural network models across multiple metrics.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def plot_metrics_comparison(results_df,
                           metrics_to_plot=None,
                           highlight_models=None,
                           figsize=(16, 10),
                           save_path=None):
    """
    Create grouped bar plot comparing all metrics across models.

    Draws one horizontal-bar subplot per metric. Models whose name contains
    any entry of ``highlight_models`` (substring match) are drawn in blue
    with bold value labels; all others are gray. The best model for each
    metric gets a gold border: lower is better for metrics whose name
    contains 'ECE', 'NLL' or 'Param', higher is better otherwise.
    Non-finite values are ignored when picking the best model and are not
    labelled.

    Args:
        results_df: DataFrame with models as rows, metrics as columns
        metrics_to_plot: List of metrics to plot (default: all numeric columns)
        highlight_models: List of model names to highlight (different color)
        figsize: Figure size
        save_path: Path to save figure (optional)

    Raises:
        ValueError: if there is no metric to plot.
    """
    if metrics_to_plot is None:
        # Non-numeric columns (e.g. config-name strings) cannot be drawn as bars.
        metrics_to_plot = results_df.select_dtypes(include='number').columns.tolist()
    if not metrics_to_plot:
        raise ValueError("No numeric metrics to plot")

    if highlight_models is None:
        highlight_models = []

    # Highlight status is per model, so compute it once for colours and labels.
    model_names = [str(name) for name in results_df.index]
    is_highlighted = [any(h in name for h in highlight_models) for name in model_names]
    colors = ['#2E86AB' if hl else '#A9A9A9' for hl in is_highlighted]  # Blue / gray
    positions = np.arange(len(model_names))

    # squeeze=False always yields a 2-D array, so one metric needs no special case.
    _, axes = plt.subplots(1, len(metrics_to_plot), figsize=figsize,
                           sharey=False, squeeze=False)

    for ax, metric in zip(axes[0], metrics_to_plot):
        metric_label = str(metric)
        values = results_df[metric].to_numpy(dtype=float)
        finite = np.isfinite(values)

        bars = ax.barh(positions, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        # Highlight best model with gold border (NaN/inf never win).
        if finite.any():
            lower_is_better = any(key in metric_label for key in ('ECE', 'NLL', 'Param'))
            if lower_is_better:
                best_idx = np.where(finite, values, np.inf).argmin()
            else:
                best_idx = np.where(finite, values, -np.inf).argmax()
            bars[best_idx].set_edgecolor('gold')
            bars[best_idx].set_linewidth(3)

        # Value labels: offset in points so spacing is independent of the metric's scale.
        for pos, val, hl in zip(positions, values, is_highlighted):
            if not np.isfinite(val):
                continue
            ax.annotate(f'{val:.3f}', xy=(val, pos),
                        xytext=(3 if val >= 0 else -3, 0), textcoords='offset points',
                        ha='left' if val >= 0 else 'right', va='center',
                        fontweight='bold' if hl else 'normal',
                        fontsize=10 if hl else 9)

        # Formatting
        ax.set_yticks(positions)
        ax.set_yticklabels(results_df.index, fontsize=10)
        ax.set_xlabel(metric_label.replace('_', ' '), fontsize=12, fontweight='bold')

        # Always include 0 and leave 15% headroom for labels; avoid a zero-width range.
        if finite.any():
            lo = min(0.0, values[finite].min() * 1.15)
            hi = max(0.0, values[finite].max() * 1.15)
        else:
            lo, hi = 0.0, 1.0
        if hi == lo:
            hi = lo + 1.0
        ax.set_xlim(lo, hi)

        ax.grid(axis='x', alpha=0.3, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")

    plt.show()


def plot_metrics_radar(results_df,
                       metrics_to_plot=None,
                       highlight_models=None,
                       figsize=(12, 12),
                       save_path=None):
    """
    Create radar/spider plot to compare model performance across metrics.

    Each metric is min-max normalised to [0, 1] across models so that 1.0 is
    always the best value. Metrics whose name contains 'ECE', 'NLL' or
    'Param' are treated as lower-is-better and inverted. A metric whose
    values are all equal is placed at 0.5. Highlighted models (substring
    match on the model name) get thick coloured lines; the rest are gray.

    Args:
        results_df: DataFrame with models as rows, metrics as columns
        metrics_to_plot: List of metrics to plot (default: all numeric columns)
        highlight_models: Models to highlight with thicker lines
        figsize: Figure size
        save_path: Path to save figure
    """
    if metrics_to_plot is None:
        # String columns (e.g. config names) cannot be normalised.
        metrics_to_plot = results_df.select_dtypes(include='number').columns.tolist()

    if highlight_models is None:
        highlight_models = []

    # Normalize metrics to [0, 1] for radar plot (higher = better after normalization).
    # astype(float) also returns a copy, so results_df is never modified.
    df_normalized = results_df[metrics_to_plot].astype(float)
    for col in df_normalized.columns:
        min_val = df_normalized[col].min()
        max_val = df_normalized[col].max()

        if max_val == min_val:
            # Avoid division by zero: all models tie, put them in the middle.
            df_normalized[col] = 0.5
        elif any(key in str(col) for key in ('ECE', 'NLL', 'Param')):
            # Lower is better: lowest value -> 1.0, highest -> 0.0
            df_normalized[col] = (max_val - df_normalized[col]) / (max_val - min_val)
        else:
            # Higher is better: highest value -> 1.0, lowest -> 0.0
            df_normalized[col] = (df_normalized[col] - min_val) / (max_val - min_val)

    # Setup radar plot
    categories = [str(m).replace('_', '\n') for m in metrics_to_plot]
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    angles += angles[:1]  # Repeat the first angle to close the polygon

    _, ax = plt.subplots(figsize=figsize, subplot_kw=dict(projection='polar'))

    # iterrows is positional, so duplicate model names do not break row lookup.
    for idx, (model_name, row) in enumerate(df_normalized.iterrows()):
        values = row.tolist()
        values += values[:1]  # Close the polygon

        if any(highlight in str(model_name) for highlight in highlight_models):
            # tab10 has 10 colours; cycle instead of saturating on the last one.
            linewidth, alpha, color = 3, 0.9, plt.cm.tab10(idx % 10)
        else:
            linewidth, alpha, color = 1.5, 0.4, 'gray'

        ax.plot(angles, values, 'o-', linewidth=linewidth,
                label=model_name, alpha=alpha, color=color)
        ax.fill(angles, values, alpha=0.15, color=color)

    # Formatting
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, size=11)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], size=9)
    ax.grid(True, linestyle='--', alpha=0.5)

    # Push legend outside to avoid overlap on crowded plots
    ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1.0),
              fontsize=10, frameon=False, ncol=1)
    ax.set_title('Model Performance Comparison\n(All metrics normalized to [0,1], higher=better)',
                 size=14, fontweight='bold', pad=20)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")

    plt.show()


def plot_uncertainty_comparison(results_df,
                                highlight_models=None,
                                figsize=(6, 6),
                                save_path=None):
    """
    Create focused plot comparing OOD-detection metrics across models.

    Selects columns whose name contains 'AUPR', 'AUROC_OOD', 'Uncertainty'
    or 'Ratio' and also 'OOD' or 'In_Domain', and draws the first three of
    them as grouped bars (one group per model). Models listed in
    ``highlight_models`` (exact name match) are drawn opaque with a thicker
    edge; the others are semi-transparent.

    Args:
        results_df: DataFrame with models as rows, metrics as columns
        highlight_models: Models to highlight (exact index labels)
        figsize: Figure size
        save_path: Path to save figure
    """
    if highlight_models is None:
        highlight_models = []
    highlighted = set(highlight_models)

    # Select uncertainty metrics, then keep the OOD-detection ones (max 3 bars per group).
    unc_metrics = [col for col in results_df.columns
                   if any(x in str(col) for x in ['AUPR', 'AUROC_OOD', 'Uncertainty', 'Ratio'])]
    ood_metrics = [col for col in unc_metrics
                   if 'OOD' in str(col) or 'In_Domain' in str(col)][:3]
    n_plotted = len(ood_metrics)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    x = np.arange(len(results_df))
    width = 0.25
    metric_colors = ['#2E86AB', '#E94F37', '#F39C12']  # Distinct colors for each metric

    for i, metric in enumerate(ood_metrics):
        bars = ax.bar(x + i * width, results_df[metric].values, width,
                      label=str(metric).replace('_', ' '),
                      color=metric_colors[i], alpha=0.8, edgecolor='black')

        # Highlight specific models with full opacity and thicker edge (exact match)
        for bar, model in zip(bars, results_df.index):
            if model in highlighted:
                bar.set_alpha(1.0)
                bar.set_edgecolor('black')
                bar.set_linewidth(2)
            else:
                bar.set_alpha(0.5)

    if not ood_metrics:
        ax.text(0.5, 0.5, 'No OOD detection metrics found',
                ha='center', va='center', transform=ax.transAxes)

    ax.set_xlabel('Models', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('OOD Detection Performance', fontsize=13, fontweight='bold')
    # Centre each tick under its group of n_plotted bars (was hard-coded for exactly 3).
    ax.set_xticks(x + width * max(n_plotted - 1, 0) / 2)
    ax.set_xticklabels(results_df.index, rotation=45, ha='right', fontsize=9)
    if ood_metrics:
        # Place legend outside the axes to keep bars readable
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1.0), loc='upper left', frameon=False)
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")

    plt.show()


def create_results_summary_table(results_df,
                                 highlight_models=None,
                                 save_path=None):
    """
    Create a formatted summary table highlighting best performances.

    Args:
        results_df: DataFrame with results
        highlight_models: Models to highlight
        save_path: Path to save table as image

    Returns:
        Styled DataFrame
    """
    if highlight_models is None:
        highlight_models = []

    # Create styled table
    def highlight_best(s):
        """Highlight best value in each column."""
        # Lower is better only for ECE and NLL
        if 'ECE' in s.name or 'NLL' in s.name:
            is_best = s == s.min()
        else:
            # Higher is better for all AUPR metrics (Success, Error, In_Domain, OOD) and AUROC
            is_best = s == s.max()
        return ['background-color: lightgreen; font-weight: bold' if v else '' for v in is_best]

    def highlight_models_style(row):
        """Highlight specified models."""
        if any(h in row.name for h in highlight_models):
            return ['background-color: lightblue'] * len(row)
        return [''] * len(row)

    # Format numeric columns only
    format_dict = {}
    for col in results_df.columns:
        # Check if column contains numeric data
        if results_df[col].dtype in ['float64', 'float32', 'int64', 'int32']:
            format_dict[col] = "{:.4f}"
        # String columns (like 'ECE_best_config') will use default formatting

    styled = results_df.style\
        .apply(highlight_best, axis=0)\
        .apply(highlight_models_style, axis=1)\
        .format(format_dict)\
        .set_properties(**{'text-align': 'center'})\
        .set_table_styles([
            {'selector': 'th', 'props': [('font-weight', 'bold'), ('text-align', 'center')]},
            {'selector': 'td', 'props': [('padding', '8px')]}
        ])

    if save_path:
        # Save as image requires additional dependencies
        print(f"Note: To save as image, use: styled.to_html() or screenshot")

    return styled
import matplotlib.pyplot as plt
import numpy as np

def plot_model_params(models, save_path="figures/model_params.png", figsize=(12, 8)):
    """
    Create bar plot of parameter counts for all models.

    A list value is treated as an ensemble, and its members' parameter
    counts are summed. Each bar is labelled with its count.

    Args:
        models: dict with model_name -> model (or list of models for ensemble);
            every model must provide ``count_params()``
        save_path: path to save figure; missing parent directories are
            created. Pass None or "" to skip saving.
        figsize: figure size (width, height)
    """
    import os

    # Collect data (dict order is preserved for names and values)
    names = list(models)
    params = [
        # Ensemble: total params across all members (also correct for
        # heterogeneous members, and 0 for an empty list)
        sum(member.count_params() for member in model)
        if isinstance(model, list) else model.count_params()
        for model in models.values()
    ]

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)

    # Cycle the palette explicitly so every bar gets a colour regardless of model count
    palette = ['#5B9BD5', '#ED7D31', '#70AD47', '#A5A5A5', '#FF6B6B']
    colors = [palette[i % len(palette)] for i in range(len(names))]

    # Bar plot
    bars = ax.bar(range(len(names)), params, color=colors,
                  edgecolor='white', linewidth=1.5)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height):,}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Formatting
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, fontsize=11)
    ax.set_ylabel('Total Parameters', fontsize=13, fontweight='bold')
    ax.set_title('Total Parameters - All Models', fontsize=16, fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Format y-axis with thousands separators
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{int(x):,}'))

    plt.tight_layout()
    if save_path:
        # The default path lives in a sub-directory that may not exist yet.
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Parameter plot saved to {save_path}")
    plt.show()
"""
Figure 1: Geometric Distinction Between Mean-Field and Low-Rank Posteriors

Publication-ready figure illustrating the key geometric insight:
- Mean-field posteriors have full support on R^(m×n): their density is positive
  everywhere, so they occupy volume (positive Lebesgue measure).
- Low-rank posteriors place all their mass on the measure-zero set
  M_r = {W : rank(W) <= r}, drawn here as a surface.

This is the main conceptual figure for the paper.
"""

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Publication-quality settings
# Publication-quality settings. These are applied to matplotlib's global
# rcParams when this module is imported, so they affect every figure
# created afterwards in the process, not just the ones defined here.
_PUBLICATION_FONT_RC = {
    'font.family': 'serif',
    'font.size': 11,
    'mathtext.fontset': 'cm',
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
}
_PUBLICATION_OUTPUT_RC = {
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'text.usetex': False,  # Set to True if LaTeX is available
}
plt.rcParams.update({**_PUBLICATION_FONT_RC, **_PUBLICATION_OUTPUT_RC})

def plot_uncertainty_distributions(models_dict, X_id, X_ood, n_samples=50, save_path=None):
    """
    Plot uncertainty distributions for in-distribution vs OOD data.

    For every model, draws one row of two density-normalised histograms
    (50 bins): predictive std on the left and mutual information on the
    right, with ID data in blue and OOD data in red. MI is clipped to
    [0, 99th percentile of the pooled ID+OOD values] so a few extreme
    samples do not flatten the histogram.

    NOTE(review): ``DeepEnsemble``, ``ensemble_predictions_with_uncertainty``
    and ``mc_predictions_with_mi`` are not defined or imported in the visible
    part of this module; confirm they are in scope where this is called.

    Parameters
    ----------
    models_dict : dict
        Dictionary of models to evaluate (name -> model)
    X_id : dict
        In-distribution data, passed unchanged to the prediction helpers
        (documented as dict; confirm against those helpers)
    X_ood : dict
        Out-of-distribution data, passed the same way as ``X_id``
    n_samples : int
        Number of MC samples for Bayesian (non-ensemble) models
    save_path : str, optional
        Path to save the figure

    Raises
    ------
    ValueError
        If ``models_dict`` is empty.
    """
    n_models = len(models_dict)
    if n_models == 0:
        raise ValueError("models_dict must contain at least one model")

    # squeeze=False keeps `axes` 2-D, even for a single model.
    fig, axes = plt.subplots(n_models, 2, figsize=(12, 3*n_models), squeeze=False)

    colors_id = '#2E86AB'  # Blue for ID
    colors_ood = '#C73E1D'  # Red for OOD

    for idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"Computing uncertainty for {model_name}...")

        # Get predictions (only MI and std are used here)
        if isinstance(model, DeepEnsemble):
            _, _, mi_id, std_id = ensemble_predictions_with_uncertainty(model, X_id)
            _, _, mi_ood, std_ood = ensemble_predictions_with_uncertainty(model, X_ood)
        else:
            _, _, mi_id, std_id = mc_predictions_with_mi(model, X_id, n_samples)
            _, _, mi_ood, std_ood = mc_predictions_with_mi(model, X_ood, n_samples)

        # Plot STD distribution
        ax = axes[idx, 0]
        ax.hist(std_id, bins=50, alpha=0.6, color=colors_id,
                label='In-Distribution', density=True, edgecolor='white')
        ax.hist(std_ood, bins=50, alpha=0.6, color=colors_ood,
                label='OOD', density=True, edgecolor='white')
        ax.set_xlabel('Predictive Std', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.set_title(f'{model_name}: Std Distribution', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

        # Plot MI distribution
        ax = axes[idx, 1]
        # Clip MI for visualization (can be very large for some models);
        # the shared upper bound is computed once from the pooled values.
        mi_upper = np.percentile(np.concatenate([mi_id, mi_ood]), 99)
        mi_id_clipped = np.clip(mi_id, 0, mi_upper)
        mi_ood_clipped = np.clip(mi_ood, 0, mi_upper)
        ax.hist(mi_id_clipped, bins=50, alpha=0.6, color=colors_id,
                label='In-Distribution', density=True, edgecolor='white')
        ax.hist(mi_ood_clipped, bins=50, alpha=0.6, color=colors_ood,
                label='OOD', density=True, edgecolor='white')
        ax.set_xlabel('Mutual Information', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.set_title(f'{model_name}: MI Distribution', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    plt.suptitle('Uncertainty Distributions: In-Distribution vs OOD',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()


def draw_transparent_cube(ax, size=2.5, alpha=0.05, edge_alpha=0.3):
    """
    Draw a semi-transparent cube representing R^(m×n).

    The six faces are drawn as flat gray surfaces with opacity ``alpha``.
    The twelve edges are drawn separately as thin black lines with opacity
    ``edge_alpha``. Previously the edges were part of the face surfaces,
    which made them inherit ``alpha`` and silently ignore ``edge_alpha``.

    Args:
        ax: A 3-D matplotlib Axes (created with ``projection='3d'``).
        size: Half side length; the cube spans [-size, size] on every axis.
        alpha: Opacity of the faces.
        edge_alpha: Opacity of the edges.
    """
    from itertools import combinations, product

    r = [-size, size]
    X, Y = np.meshgrid(r, r)

    # All six faces
    faces = [
        # Bottom and top (z = ±size)
        (X, Y, np.ones_like(X) * -size),
        (X, Y, np.ones_like(X) * size),
        # Front and back (y = ±size)
        (X, np.ones_like(X) * -size, Y),
        (X, np.ones_like(X) * size, Y),
        # Left and right (x = ±size)
        (np.ones_like(X) * -size, X, Y),
        (np.ones_like(X) * size, X, Y),
    ]

    for xs, ys, zs in faces:
        # No edges here: a surface's alpha would override the edge colour's alpha.
        ax.plot_surface(xs, ys, zs,
                        alpha=alpha, color='gray',
                        edgecolor='none', linewidth=0,
                        shade=False)

    # Cube edges join corners that differ in exactly one coordinate (12 edges).
    corners = list(product(r, repeat=3))
    for a, b in combinations(corners, 2):
        if sum(ca != cb for ca, cb in zip(a, b)) == 1:
            ax.plot(*zip(a, b), color='black', alpha=edge_alpha, linewidth=0.3)


def _torus_xyz(theta, phi, major_radius, minor_radius):
    """Return (x, y, z) of the torus parametrised by angles theta (around the axis) and phi (around the tube)."""
    x = (major_radius + minor_radius * np.cos(phi)) * np.cos(theta)
    y = (major_radius + minor_radius * np.cos(phi)) * np.sin(theta)
    z = minor_radius * np.sin(phi)
    return x, y, z


def _style_weight_space_axes(ax):
    """Apply the axis labels, limits, view angle and grid shared by all panels of Figure 1."""
    ax.set_xlabel(r'$w_1$', fontsize=13, labelpad=10)
    ax.set_ylabel(r'$w_2$', fontsize=13, labelpad=10)
    ax.set_zlabel(r'$w_3$', fontsize=13, labelpad=10)
    ax.set_xlim([-2.8, 2.8])
    ax.set_ylim([-2.8, 2.8])
    ax.set_zlim([-2.8, 2.8])
    ax.view_init(elev=25, azim=50)
    ax.grid(True, alpha=0.2, linestyle='--', linewidth=0.5)


def create_figure(figsize=(18, 8), save_path='figure_1_geometric_distinction.png'):
    """
    Create conceptual figure showing geometric distinction
    between mean-field and low-rank posteriors.

    Panel A draws a torus, standing in for the rank-r set M_r, inside a cube
    standing in for R^(m×n). Panel B scatters clipped Gaussian samples that
    fill the cube's volume. Panel C scatters samples lying on the torus,
    jittered slightly along its normal.

    Sampling uses a local ``np.random.RandomState(42)``. This produces the
    same points the former global ``np.random.seed(42)`` did, without
    resetting NumPy's process-wide RNG as a side effect.

    Args:
        figsize: Figure size (width, height) in inches.
        save_path: Output image path; falsy to skip saving.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=figsize)

    # Local, seeded generator (identical stream to the legacy global RNG seeded with 42).
    rng = np.random.RandomState(42)

    # Rank-r manifold stand-in: a torus, a 2-D surface embedded in 3-D.
    R = 1.5  # Major radius
    r = 0.6  # Minor radius
    THETA, PHI = np.meshgrid(np.linspace(0, 2*np.pi, 80), np.linspace(0, 2*np.pi, 80))
    X_manifold, Y_manifold, Z_manifold = _torus_xyz(THETA, PHI, R, r)

    # ========================================================================
    # PANEL A: The Ambient Space and Manifold Structure
    # ========================================================================
    ax1 = fig.add_subplot(131, projection='3d')
    draw_transparent_cube(ax1, size=2.5, alpha=0.06, edge_alpha=0.25)

    # Manifold with gradient coloring plus a subtle wireframe for structure
    ax1.plot_surface(X_manifold, Y_manifold, Z_manifold,
                     cmap='Blues', alpha=0.7,
                     edgecolor='none', shade=True,
                     vmin=-0.5, vmax=0.5)
    ax1.plot_wireframe(X_manifold, Y_manifold, Z_manifold,
                       color="#810474", alpha=0.15, linewidth=0.4,
                       rstride=6, cstride=6)
    _style_weight_space_axes(ax1)

    ax1.text2D(0.5, 0.97, r'$\mathbb{R}^{m \times n}$',
               transform=ax1.transAxes, ha='center', fontsize=11,
               bbox=dict(boxstyle='round,pad=0.4', facecolor='white',
                         edgecolor='#424242', linewidth=1.5, alpha=0.95))
    ax1.text2D(0.5, 0.03, r'$\mathcal{M}_r = \{\mathbf{W} : \mathrm{rank}(\mathbf{W}) \leq r\}$',
               transform=ax1.transAxes, ha='center', fontsize=10,
               color='#0D47A1', style='italic',
               bbox=dict(boxstyle='round,pad=0.4', facecolor='#E3F2FD',
                         edgecolor='#1976D2', linewidth=1.3, alpha=0.95))
    ax1.set_title(r'$\mathbf{(A)}$ Weight Space Structure',
                  fontsize=14, fontweight='bold', pad=20)

    # ========================================================================
    # PANEL B: Mean-Field Posterior (Full Volumetric Support)
    # ========================================================================
    ax2 = fig.add_subplot(132, projection='3d')
    draw_transparent_cube(ax2, size=2.5, alpha=0.06, edge_alpha=0.25)

    ax2.plot_surface(X_manifold, Y_manifold, Z_manifold,
                     color="#2D01EE", alpha=0.25, edgecolor='none')
    ax2.plot_wireframe(X_manifold, Y_manifold, Z_manifold,
                       color="#FF3C00", alpha=0.25, linewidth=0.5,
                       rstride=8, cstride=8)

    # Mean-field: isotropic Gaussian cloud, clipped to the cube bounds.
    # Draw order (x, y, z) is kept so the sampled points are reproducible.
    n_points = 2500
    sigma = 1.3
    x_mf = np.clip(rng.normal(0, sigma, n_points), -2.5, 2.5)
    y_mf = np.clip(rng.normal(0, sigma, n_points), -2.5, 2.5)
    z_mf = np.clip(rng.normal(0, sigma, n_points), -2.5, 2.5)

    ax2.scatter(x_mf, y_mf, z_mf,
                s=20, alpha=0.5, edgecolors='cyan', linewidth=0.5,
                depthshade=True)
    _style_weight_space_axes(ax2)

    ax2.text2D(0.5, 0.97, r'$q_{\mathrm{MF}}(\mathbf{W})$',
               transform=ax2.transAxes, ha='center', fontsize=12,
               fontweight='bold', color='#B71C1C',
               bbox=dict(boxstyle='round,pad=0.5', facecolor='#FFEBEE',
                         edgecolor='#C62828', linewidth=1.8, alpha=0.95))
    ax2.text2D(0.5, 0.03, r'Full support (volume)',
               transform=ax2.transAxes, ha='center', fontsize=10,
               style='italic', color='#8B0000')
    ax2.set_title(r'$\mathbf{(B)}$ Mean-Field Posterior',
                  fontsize=14, fontweight='bold', pad=20)

    # ========================================================================
    # PANEL C: Low-Rank Posterior (Concentrated on Manifold)
    # ========================================================================
    ax3 = fig.add_subplot(133, projection='3d')
    draw_transparent_cube(ax3, size=2.5, alpha=0.06, edge_alpha=0.25)

    # Manifold drawn prominently
    ax3.plot_surface(X_manifold, Y_manifold, Z_manifold,
                     cmap='Greens', alpha=0.7, edgecolor='none',
                     shade=True, vmin=-0.5, vmax=0.5)
    ax3.plot_wireframe(X_manifold, Y_manifold, Z_manifold,
                       color="#D351CC", alpha=0.4, linewidth=0.8,
                       rstride=6, cstride=6)

    # Low-rank: points uniformly sampled in (theta, phi) ON the torus
    n_points_lr = 1000
    theta_samples = rng.uniform(0, 2*np.pi, n_points_lr)
    phi_samples = rng.uniform(0, 2*np.pi, n_points_lr)
    noise_normal = rng.normal(0, 0.06, n_points_lr)

    X_lr, Y_lr, Z_lr = _torus_xyz(theta_samples, phi_samples, R, r)

    # Jitter along the torus's unit outward normal (cos φ cos θ, cos φ sin θ, sin φ)
    X_lr += noise_normal * np.cos(theta_samples) * np.cos(phi_samples)
    Y_lr += noise_normal * np.sin(theta_samples) * np.cos(phi_samples)
    Z_lr += noise_normal * np.sin(phi_samples)

    # Cyan edges give high contrast with the green torus
    ax3.scatter(X_lr, Y_lr, Z_lr,
                s=15, alpha=0.5, edgecolors='cyan', linewidth=0.4,
                depthshade=True)
    _style_weight_space_axes(ax3)

    ax3.text2D(0.5, 0.97, r'$q_{\mathrm{LR}}(\mathbf{A}, \mathbf{B})$',
               transform=ax3.transAxes, ha='center', fontsize=12,
               fontweight='bold', color="#03830C",
               bbox=dict(boxstyle='round,pad=0.5', facecolor="#DBF5DD",
                         edgecolor="#388E3C", linewidth=1.8, alpha=0.95))
    ax3.text2D(0.5, 0.03, r'Concentrated on $\mathcal{M}_r$ (surface)',
               transform=ax3.transAxes, ha='center', fontsize=10,
               style='italic', color='#1B5E20')
    ax3.set_title(r'$\mathbf{(C)}$ Low-Rank Posterior',
                  fontsize=14, fontweight='bold', pad=20)

    # ========================================================================
    # Overall title
    # ========================================================================
    fig.suptitle('Geometric Distinction: Mean-Field vs. Low-Rank Posteriors',
                 fontsize=16, fontweight='bold', y=0.98)

    # Leave room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.02, 1, 0.96])

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        print(f"\n{'='*80}")
        print(f"✓ Figure saved to: {save_path}")
        print(f"{'='*80}")
        print("  • Mean-field: Full-dimensional support (volume)")
        print("  • Low-rank: Measure-zero manifold (surface)")
        print("\nThis is the conceptual foundation of the paper!")
        print(f"{'='*80}\n")

    plt.show()
    return fig


"""
Advanced Visualization Functions for Bayesian Neural Networks

This module provides publication-quality visualizations including:
- Uncertainty decomposition bar charts
- Weight correlation heatmaps
- ROC curves with confidence intervals
- Reliability diagrams
- Singular value analysis
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from scipy import stats
import pandas as pd


def _stack_prediction_samples(model, X, n_samples):
    """
    Return predictions on X stacked along a new leading axis.

    A list is treated as a deep ensemble: one ``predict`` per member, and
    ``n_samples`` is ignored. Any other model is called ``n_samples`` times
    with ``training=True``, presumably to keep its stochastic layers active
    (TODO confirm for each model type).
    """
    if isinstance(model, list):
        return np.array([member.predict(X, verbose=0).squeeze() for member in model])
    return np.array([model(X, training=True).numpy().squeeze() for _ in range(n_samples)])


def plot_uncertainty_decomposition(models_dict, X_test, n_samples=512,
                                   num_examples=10, model_names=None,
                                   figsize=(16, 8), save_path=None):
    """
    Plot uncertainty decomposition (Total vs Epistemic vs Aleatoric) for test examples.
    Selects examples with high aleatoric uncertainty for informative visualization.

    Example selection: a reference model (the first non-'Deterministic Dense'
    model) is run on ``X_test[:1000]`` with at most 50 MC samples. The
    ``3 * num_examples`` examples with the highest aleatoric uncertainty are
    then sub-sampled evenly down to ``num_examples``. For each model, the
    aleatoric part (bottom) and the epistemic part (mutual information, top)
    are drawn as a stacked bar per example. Model groups are centred on each
    tick and never overlap.

    Parameters:
    -----------
    models_dict : dict of trained models (a list value is treated as a deep ensemble)
    X_test : test features
    n_samples : number of MC samples
    num_examples : number of test examples to show (clamped to the number available)
    model_names : list of model names to include (default: all known models present)
    figsize : figure size
    save_path : path to save figure

    Raises:
    -------
    ValueError : if no model is selected or no test example is available.
    """
    from modules.inference import compute_total_uncertainty, compute_aleatoric_uncertainty, compute_mutual_information

    if model_names is None:
        default_names = [
            'Full-Rank BBB',
            'Low-Rank Gaussian',
            'Deep Ensemble',
            'Deterministic Dense'
        ]
        model_names = [name for name in default_names if name in models_dict]
    if not model_names:
        raise ValueError("No models to plot: model_names is empty or none of the "
                         "default models are present in models_dict")

    # Distinct colors for each model
    model_colors = {
        'Full-Rank BBB': {'epistemic': '#E63946', 'aleatoric': '#F4A261'},
        'Low-Rank Gaussian': {'epistemic': '#2A9D8F', 'aleatoric': '#A7C957'},
        'Deep Ensemble': {'epistemic': '#264653', 'aleatoric': '#457B9D'},
        'Deterministic Dense': {'epistemic': '#6C757D', 'aleatoric': '#ADB5BD'}
    }

    # First pass: quick aleatoric estimate from a reference model to select informative examples
    ref_name = next((name for name in model_names if name != 'Deterministic Dense'), model_names[0])
    mc_samples_ref = _stack_prediction_samples(models_dict[ref_name], X_test[:1000], min(50, n_samples))
    aleatoric_ref = np.asarray(compute_aleatoric_uncertainty(mc_samples_ref)).ravel()

    # Clamp so the slice step below is >= 1 (a zero step raised ValueError)
    # and exactly num_examples indices are selected.
    num_examples = min(num_examples, aleatoric_ref.size)
    if num_examples < 1:
        raise ValueError("num_examples must be >= 1 and X_test must be non-empty")

    # Highest-aleatoric candidates, then sample evenly among them
    top_aleatoric_indices = np.argsort(aleatoric_ref)[-num_examples * 3:]
    step = len(top_aleatoric_indices) // num_examples
    example_indices = top_aleatoric_indices[::step][:num_examples]

    print(f"Selected {num_examples} examples")

    # Compute uncertainties for each model on the selected examples
    X_selected = X_test[example_indices]
    uncertainties = {}
    for model_name in model_names:
        mc_samples = _stack_prediction_samples(models_dict[model_name], X_selected, n_samples)

        total = compute_total_uncertainty(mc_samples)
        aleatoric = compute_aleatoric_uncertainty(mc_samples)
        epistemic = compute_mutual_information(mc_samples)

        uncertainties[model_name] = {
            'total': total,
            'epistemic': epistemic,
            'aleatoric': aleatoric
        }

    # Create grouped bar chart
    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(num_examples)
    n_models = len(model_names)
    # Share 80% of each slot between the models so bars never overlap
    # and every group is centred on its tick.
    width = 0.8 / n_models
    offsets = (np.arange(n_models) - (n_models - 1) / 2) * width

    for offset, model_name in zip(offsets, model_names):
        epistemic_vals = uncertainties[model_name]['epistemic']
        aleatoric_vals = uncertainties[model_name]['aleatoric']

        colors = model_colors.get(model_name, {'epistemic': '#555555', 'aleatoric': '#888888'})

        # Stacked bars: aleatoric at the bottom, epistemic on top
        ax.bar(x + offset, aleatoric_vals, width,
               label=f'{model_name} (Aleatoric)',
               color=colors['aleatoric'], alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.bar(x + offset, epistemic_vals, width, bottom=aleatoric_vals,
               label=f'{model_name} (Epistemic)',
               color=colors['epistemic'], alpha=0.8, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Test Example Index', fontsize=12, fontweight='bold')
    ax.set_ylabel('Uncertainty (nats)', fontsize=12, fontweight='bold')
    ax.set_title('Uncertainty Decomposition: Total = Epistemic + Aleatoric\n(Selected examples with high aleatoric uncertainty)',
                 fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([f'{idx}' for idx in example_indices], rotation=0)
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', frameon=True, fontsize=9)
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    if save_path:
        print(f"Uncertainty decomposition plot saved to '{save_path}'")


def plot_weight_correlation_heatmap(models_dict, layer_idx=0, figsize=(16, 6), save_path=None):
    """
    Plot weight correlation heatmaps showing diagonal structure (full-rank) vs block structure (low-rank).

    Left panel: the ``w_mu`` matrix of the ``layer_idx``-th layer (0-based, counting
    only layers that have ``w_mu``) of ``models_dict['Full-Rank BBB']``.
    Right panel: ``A_mu @ B_mu.T`` reconstructed from the ``layer_idx``-th layer
    (counting only layers that have both ``A_mu`` and ``B_mu``) of
    ``models_dict['Low-Rank Gaussian']``.
    Each panel shows ``np.corrcoef`` of the matrix columns (output neurons).
    A panel shows a placeholder text when no matching layer exists.

    Parameters:
    -----------
    models_dict : dict of trained models; must contain 'Full-Rank BBB' and
        'Low-Rank Gaussian' (a missing key raises KeyError)
    layer_idx : which Bayesian layer to visualize (0 = first Bayesian layer)
    figsize : figure size
    save_path : path to save figure
    """
    def _nth_layer_with(model, attrs, n):
        # Return the n-th (0-based) layer of `model` exposing every attribute in `attrs`, else None.
        seen = 0
        for layer in model.layers:
            if all(hasattr(layer, attr) for attr in attrs):
                if seen == n:
                    return layer
                seen += 1
        return None

    def _draw_corr(ax, W, title):
        # Correlation between columns of W, i.e. between output neurons.
        corr = np.corrcoef(W.T)  # (output_dim, output_dim)
        im = ax.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto', interpolation='nearest')
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel('Output Neuron', fontsize=10)
        ax.set_ylabel('Output Neuron', fontsize=10)
        plt.colorbar(im, ax=ax, label='Correlation', fraction=0.046)

    _, axes = plt.subplots(1, 2, figsize=figsize)

    # Full-Rank BBB: mean weight matrix of the requested Bayesian layer
    fullrank_layer = _nth_layer_with(models_dict['Full-Rank BBB'], ('w_mu',), layer_idx)
    if fullrank_layer is not None:
        W_fullrank = fullrank_layer.w_mu.numpy()  # (input_dim, output_dim)
        print(f"Full-Rank: Found Bayesian layer with shape {W_fullrank.shape}")
        _draw_corr(axes[0], W_fullrank, 'Full-Rank BBB: Diagonal Correlation Structure')
    else:
        axes[0].text(0.5, 0.5, 'No Bayesian layer found', ha='center', va='center')

    # Low-Rank Gaussian: reconstruct the mean weight matrix from its factors
    lowrank_layer = _nth_layer_with(models_dict['Low-Rank Gaussian'], ('A_mu', 'B_mu'), layer_idx)
    if lowrank_layer is not None:
        A = lowrank_layer.A_mu.numpy()  # (input_dim, rank)
        B = lowrank_layer.B_mu.numpy()  # (output_dim, rank)
        W_lowrank = A @ B.T  # (input_dim, output_dim) - reconstructed weight matrix
        print(f"Low-Rank: Found Bayesian layer with shape {W_lowrank.shape} (rank={A.shape[1]})")
        _draw_corr(axes[1], W_lowrank, 'Low-Rank Gaussian: Block Correlation Structure')
    else:
        axes[1].text(0.5, 0.5, 'No low-rank layer found', ha='center', va='center')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Weight correlation heatmap saved to '{save_path}'" if save_path else "")


def plot_roc_curves_ood(models_dict, X_test, y_test, X_ood, y_ood,
                        n_samples=200, n_bootstrap=10, figsize=(12, 8), save_path=None,
                        random_state=None):
    """
    Plot ROC curves for OOD detection with confidence intervals from bootstrap.

    For every known model present in ``models_dict``, the mutual information
    returned by ``mc_predictions_with_mi_v2`` is the OOD score (label 0 =
    in-domain, 1 = OOD). The pooled scores are bootstrap-resampled
    ``n_bootstrap`` times. The plotted curve is the mean TPR interpolated on a
    fixed FPR grid, and the shaded band is +/- one standard deviation clipped
    to [0, 1].

    Parameters:
    -----------
    models_dict : dict of trained models (names outside the fixed list below are ignored)
    X_test : in-domain test features
    y_test : in-domain test labels (unused; kept for API compatibility)
    X_ood : out-of-domain features
    y_ood : out-of-domain labels (unused; kept for API compatibility)
    n_samples : number of MC samples
    n_bootstrap : number of bootstrap iterations for confidence intervals (>= 1)
    figsize : figure size
    save_path : path to save figure
    random_state : optional seed for the bootstrap resampling. ``None`` keeps the
        previous behaviour (NumPy's global RNG); an int makes the bands reproducible.

    Raises:
    -------
    ValueError : if ``n_bootstrap < 1``.
    """
    # Imported locally (like mc_predictions_with_mi_v2) so this function does not
    # rely on a module-level sklearn import.
    from sklearn.metrics import roc_curve, auc
    from modules.inference import mc_predictions_with_mi_v2

    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")

    # np.random.choice and Generator.choice share the call signature used below.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    model_names = [
        'Full-Rank BBB',
        'Low-Rank Gaussian',
        'Low-Rank Laplace',
        'Rank-1 multiplicative',
        'Deep Ensemble',
        'Deterministic Dense'
    ]
    colors = {
        'Full-Rank BBB': '#E63946',
        'Low-Rank Gaussian': '#2A9D8F',
        'Low-Rank Laplace': '#277DA1',
        'Rank-1 multiplicative': '#577590',
        'Deep Ensemble': '#F4A261',
        'Deterministic Dense': '#6C757D'
    }
    mean_fpr = np.linspace(0, 1, 100)

    fig, ax = plt.subplots(figsize=figsize)

    for model_name in model_names:
        if model_name not in models_dict:
            continue

        model = models_dict[model_name]

        # Compute MI-based uncertainty for in-domain and OOD
        _, _, mi_in = mc_predictions_with_mi_v2(model, X_test, n_samples=n_samples, seed=42)
        _, _, mi_ood = mc_predictions_with_mi_v2(model, X_ood, n_samples=n_samples, seed=42)
        mi_in = np.ravel(mi_in)
        mi_ood = np.ravel(mi_ood)

        # Create binary labels: 0 = in-domain, 1 = OOD
        y_true = np.concatenate([np.zeros(len(mi_in)), np.ones(len(mi_ood))])
        uncertainty_scores = np.concatenate([mi_in, mi_ood])
        n_total = len(y_true)

        # Bootstrap for confidence intervals
        aucs = []
        tprs_interp = []
        for _ in range(n_bootstrap):
            indices = rng.choice(n_total, size=n_total, replace=True)
            y_boot = y_true[indices]
            # A resample with a single class has no defined ROC curve (roc_curve
            # returns NaN rates), which would turn the mean curve and AUC into NaN.
            if len(np.unique(y_boot)) < 2:
                continue

            fpr, tpr, _ = roc_curve(y_boot, uncertainty_scores[indices])
            aucs.append(auc(fpr, tpr))

            # Interpolate onto the common FPR grid so curves can be averaged
            tpr_interp = np.interp(mean_fpr, fpr, tpr)
            tpr_interp[0] = 0.0
            tprs_interp.append(tpr_interp)

        if not aucs:
            print(f"Skipping {model_name}: no bootstrap resample contained both in-domain and OOD examples")
            continue

        # Plot mean ROC curve
        mean_tpr = np.mean(tprs_interp, axis=0)
        mean_auc = np.mean(aucs)
        std_auc = np.std(aucs)
        std_tpr = np.std(tprs_interp, axis=0)

        ax.plot(mean_fpr, mean_tpr, color=colors[model_name], lw=2.5,
                label=f'{model_name} (AUC = {mean_auc:.3f} ± {std_auc:.3f})')

        # Confidence band (+/- 1 std), clipped to valid rates
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[model_name], alpha=0.2)

    # Diagonal reference line
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Random Classifier')

    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    ax.set_title('ROC Curves: OOD Detection using Mutual Information', fontsize=14, fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fontsize=10)
    ax.grid(alpha=0.3, linestyle='--')
    ax.set_xlim([-0.02, 1.02])
    ax.set_ylim([-0.02, 1.02])

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"ROC curves saved to '{save_path}'" if save_path else "")


def plot_reliability_diagrams(models_dict, X_test, y_test, n_bins=10,
                              n_samples=512, figsize=(15, 10), save_path=None):
    """
    Plot reliability diagrams (calibration plots) for each model.

    One panel per model on a three-column grid. Each panel shows:
    - the mean predicted probability vs. the observed positive rate for
      ``n_bins`` equal-width bins on [0, 1];
    - a histogram of the predictions on those same bins;
    - the expected calibration error (ECE) in the title.

    Parameters:
    -----------
    models_dict : dict of trained models; a list value is treated as a deep
        ensemble whose members' ``predict`` outputs are averaged
    X_test : test features
    y_test : test labels (array-like, one label per test example)
    n_bins : number of bins for calibration
    n_samples : number of MC samples for Bayesian models
    figsize : figure size
    save_path : path to save figure

    Raises:
    -------
    ValueError : if ``models_dict`` is empty.
    """
    from modules.inference import mc_predictions_with_mi_v2

    model_names = list(models_dict.keys())
    n_models = len(model_names)
    if n_models == 0:
        raise ValueError("models_dict is empty; nothing to plot")
    ncols = 3
    nrows = int(np.ceil(n_models / ncols))

    _, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # With ncols=3, subplots always returns an ndarray of Axes (1-D when nrows == 1).
    # Flatten it uniformly: wrapping it in a list for a single model would leave
    # axes[0] as an array rather than an Axes.
    axes = np.atleast_1d(axes).ravel()

    y_true = np.asarray(y_test).ravel()
    bin_edges = np.linspace(0, 1, n_bins + 1)

    for ax, model_name in zip(axes, model_names):
        model = models_dict[model_name]

        # Get predictions
        if isinstance(model, list):  # Deep Ensemble
            preds = np.mean([m.predict(X_test, verbose=0).squeeze() for m in model], axis=0)
        else:  # Bayesian or deterministic model
            preds, _, _ = mc_predictions_with_mi_v2(model, X_test, n_samples=n_samples, seed=42)
        preds = np.asarray(preds).ravel()

        # Per-bin observed rate, mean confidence and count (NaN / 0 for empty bins)
        true_probs = np.full(n_bins, np.nan)
        pred_probs = np.full(n_bins, np.nan)
        counts = np.zeros(n_bins, dtype=int)
        for i in range(n_bins):
            lo, hi = bin_edges[i], bin_edges[i + 1]
            # Bins are [lo, hi) except the last, which is closed at 1.0 so that
            # predictions of exactly 1.0 are not silently dropped.
            upper_ok = preds <= hi if i == n_bins - 1 else preds < hi
            in_bin = (preds >= lo) & upper_ok
            if in_bin.any():
                true_probs[i] = y_true[in_bin].mean()
                pred_probs[i] = preds[in_bin].mean()
                counts[i] = in_bin.sum()
        valid = counts > 0

        # Plot calibration curve
        ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Perfect Calibration')
        ax.plot(pred_probs[valid], true_probs[valid],
                'o-', markersize=8, lw=2, color='#E63946', label='Model Calibration')

        # Histogram of predictions, on the same bins used for calibration
        ax2 = ax.twinx()
        ax2.hist(preds, bins=bin_edges, alpha=0.3, color='gray', edgecolor='black')
        ax2.set_ylabel('Count', fontsize=9)
        ax2.tick_params(axis='y', labelsize=8)

        # ECE: count-weighted mean |observed rate - mean confidence| over non-empty bins
        total = counts.sum()
        if total > 0:
            ece = np.sum(np.abs(true_probs[valid] - pred_probs[valid]) * counts[valid]) / total
        else:
            ece = float('nan')

        ax.set_xlabel('Predicted Probability', fontsize=10)
        ax.set_ylabel('True Probability', fontsize=10)
        ax.set_title(f'{model_name}\nECE = {ece:.4f}', fontsize=11, fontweight='bold')
        ax.legend(loc='upper left', fontsize=8)
        ax.grid(alpha=0.3, linestyle='--')
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])

    # Hide unused subplots
    for unused_ax in axes[n_models:]:
        unused_ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Reliability diagrams saved to '{save_path}'" if save_path else "")


def analyze_singular_values(model, rank=15, figsize=(16, 10), save_path=None):
    """
    Analyze and plot singular values of ALL learned weight matrices from deterministic model.

    A "dense layer" here is any layer in ``model.layers`` whose first entry of
    ``get_weights()`` is 2-D. For each one, the figure shows:
    - top row: the singular values (log scale);
    - bottom row: the cumulative energy ``cumsum(s**2) / sum(s**2) * 100``.
    Both rows mark the chosen ``rank``.

    Parameters:
    -----------
    model : trained deterministic model
    rank : chosen rank for low-rank approximation (>= 1)
    figsize : figure size
    save_path : path to save figure

    Returns:
    --------
    list of float : energy (%) captured by the top ``min(rank, len(s))``
        singular values of each dense layer, in layer order.

    Raises:
    -------
    ValueError : if ``rank < 1`` or the model has no dense (2-D weight) layer.
    """
    if rank < 1:
        # rank=0 would otherwise index cumulative_energy[-1] and report 100%.
        raise ValueError(f"rank must be >= 1, got {rank}")

    # Extract all dense layers: (position in model.layers, layer, kernel)
    dense_layers = []
    for i, layer in enumerate(model.layers):
        if not hasattr(layer, 'get_weights'):
            continue
        weights = layer.get_weights()
        if len(weights) > 0 and np.ndim(weights[0]) == 2:  # Only dense layers
            dense_layers.append((i, layer, weights[0]))

    n_layers = len(dense_layers)
    print(f"Found {n_layers} dense layers in model")
    if n_layers == 0:
        raise ValueError("model has no dense (2-D weight) layers to analyze")

    # Create subplots: 2 rows (singular values + cumulative energy) x n_layers columns
    fig, axes = plt.subplots(2, n_layers, figsize=figsize)
    if n_layers == 1:
        axes = axes.reshape(2, 1)

    all_energies = []

    for col_idx, (layer_idx, _layer, W) in enumerate(dense_layers):
        # Only the singular values are needed; skip computing U and V^T.
        s = np.linalg.svd(W, compute_uv=False)

        # Cumulative energy (%); an all-zero matrix has no energy, so report 0 instead of 0/0 = NaN.
        total_energy = np.sum(s**2)
        if total_energy > 0:
            cumulative_energy = np.cumsum(s**2) / total_energy * 100
        else:
            cumulative_energy = np.zeros_like(s)

        # Energy captured by chosen rank (capped at the number of singular values)
        rank_in_range = rank <= len(s)
        energy_at_rank = cumulative_energy[min(rank, len(s)) - 1]
        all_energies.append(energy_at_rank)

        # Plot 1: Singular values (top row)
        ax1 = axes[0, col_idx]
        ax1.plot(range(1, len(s)+1), s, 'o-', color='#264653', markersize=4, lw=2)
        if rank_in_range:
            ax1.axvline(x=rank, color='#E63946', linestyle='--', lw=2,
                        label=f'r={rank}')
            # Only add a legend when a labeled artist exists (avoids a matplotlib warning).
            ax1.legend(loc='upper right', fontsize=8)
        ax1.set_xlabel('Index (i)', fontsize=10, fontweight='bold')
        ax1.set_ylabel('σᵢ(W*)', fontsize=10, fontweight='bold')
        ax1.set_title(f'Layer {layer_idx}: {W.shape}',
                      fontsize=11, fontweight='bold')
        ax1.grid(alpha=0.3, linestyle='--')
        ax1.set_yscale('log')

        # Plot 2: Cumulative energy (bottom row)
        ax2 = axes[1, col_idx]
        ax2.plot(range(1, len(cumulative_energy)+1), cumulative_energy, 'o-',
                 color='#2A9D8F', markersize=4, lw=2)
        if rank_in_range:
            ax2.axvline(x=rank, color='#E63946', linestyle='--', lw=2,
                        label=f'{energy_at_rank:.1f}%')
            ax2.axhline(y=energy_at_rank, color='#E63946', linestyle=':', lw=1.5, alpha=0.7)
            ax2.legend(loc='lower right', fontsize=8)
        ax2.set_xlabel('Rank', fontsize=10, fontweight='bold')
        ax2.set_ylabel('Energy (%)', fontsize=10, fontweight='bold')
        ax2.set_title(f'r={rank}: {energy_at_rank:.1f}%',
                      fontsize=11, fontweight='bold')
        ax2.grid(alpha=0.3, linestyle='--')
        ax2.set_ylim([0, 105])

        print(f"\nLayer {layer_idx} ({W.shape}):")
        print(f"  Total singular values: {len(s)}")
        print(f"  Energy at rank {rank}: {energy_at_rank:.2f}%")
        print(f"  Top 5 singular values: {s[:5]}")

    fig.suptitle(f'Singular Value Analysis: All Layers (Chosen rank r={rank})',
                 fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    avg_energy = np.mean(all_energies)
    print(f"\n{'='*80}")
    print(f"SUMMARY: Rank r={rank} captures {avg_energy:.2f}% energy on average across all layers")
    print(f"{'='*80}")

    if save_path:
        print(f"Plot saved to '{save_path}'")

    return all_energies


def plot_rank_heatmaps(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create heatmaps showing optimal rank combinations for each metric.
    One heatmap per KL scale value.

    The figure layout is:
    - rows: the metrics NLL, ECE, AUROC_OOD_MI and AUPR_OOD_MI;
    - columns: the sorted distinct values of ``df['kl_scale_parsed']``.
    Each cell is the mean metric over rows sharing an (r1, r2) pair. The best
    cell (minimum for NLL/ECE, maximum otherwise) is outlined and reported in
    the title. Panels with no data are blanked.

    Args:
        df: results with columns 'r1', 'r2', 'kl_scale_parsed' and the metric columns.
        output_dir: directory for ``rank_heatmaps_by_kl.png`` (created if missing).
    """
    import os
    import seaborn as sns  # local import: `sns` is not imported in this module's header

    metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']
    metric_labels = {
        'NLL': 'NLL (lower is better)',
        'ECE': 'ECE (lower is better)',
        'AUROC_OOD_MI': 'AUROC OOD (higher is better)',
        'AUPR_OOD_MI': 'AUPR OOD (higher is better)'
    }

    kl_scales = sorted(df['kl_scale_parsed'].unique())

    # rows = metrics, cols = kl_scales; squeeze=False keeps axes 2-D for a single KL scale
    _, axes = plt.subplots(len(metrics), len(kl_scales),
                           figsize=(5*len(kl_scales), 4*len(metrics)), squeeze=False)

    for i, metric in enumerate(metrics):
        lower_is_better = metric in ['NLL', 'ECE']
        # Red = bad and green = good in both directions
        cmap = 'RdYlGn_r' if lower_is_better else 'RdYlGn'

        for j, kl_scale in enumerate(kl_scales):
            ax = axes[i, j]

            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            pivot = df_kl.pivot_table(index='r1', columns='r2', values=metric, aggfunc='mean')

            cells = pivot.stack().dropna()
            if cells.empty:
                ax.set_title(f'{metric_labels[metric]}\nKL Scale={kl_scale}\nNo data')
                ax.axis('off')
                continue

            if lower_is_better:
                best_idx = cells.idxmin()
                best_val = cells.min()
            else:
                best_idx = cells.idxmax()
                best_val = cells.max()

            sns.heatmap(pivot, ax=ax, cmap=cmap, annot=True, fmt='.3f',
                        cbar_kws={'label': metric})

            # Mark best cell: seaborn draws cell (row, col) on the unit square at (col, row)
            row_pos = pivot.index.get_loc(best_idx[0])
            col_pos = pivot.columns.get_loc(best_idx[1])
            ax.add_patch(plt.Rectangle((col_pos, row_pos), 1, 1,
                                       fill=False, edgecolor='black', lw=3))

            ax.set_title(f'{metric_labels[metric]}\nKL Scale={kl_scale}\nBest: r1={best_idx[0]}, r2={best_idx[1]} ({best_val:.3f})')
            ax.set_xlabel('Rank 2 (Layer 2)')
            ax.set_ylabel('Rank 1 (Layer 1)')

    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/rank_heatmaps_by_kl.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {output_dir}/rank_heatmaps_by_kl.png")


def plot_rank_combined_summary(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create a summary plot showing the best rank combinations across all KL scales.
    Aggregates results to find overall optimal ranks.

    Each of the four panels (NLL, ECE, AUROC_OOD_MI, AUPR_OOD_MI) is a heatmap
    of the metric averaged over every row of ``df`` sharing an (r1, r2) pair,
    i.e. pooled across KL scales. The best cell (minimum for NLL/ECE, maximum
    otherwise) is outlined and reported in the title. Panels with no data are
    blanked.

    Args:
        df: results with columns 'r1', 'r2' and the metric columns.
        output_dir: directory for ``rank_optimal_summary.png`` (created if missing).
    """
    import os
    import seaborn as sns  # local import: `sns` is not imported in this module's header

    metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']

    _, axes = plt.subplots(2, 2, figsize=(14, 12))

    for ax, metric in zip(axes.flatten(), metrics):
        # Aggregate across all KL scales (mean)
        pivot = df.pivot_table(index='r1', columns='r2', values=metric, aggfunc='mean')

        lower_is_better = metric in ['NLL', 'ECE']
        direction = 'lower' if lower_is_better else 'higher'

        cells = pivot.stack().dropna()
        if cells.empty:
            ax.set_title(f'{metric} ({direction} is better)\nNo data',
                         fontsize=12, fontweight='bold')
            ax.axis('off')
            continue

        if lower_is_better:
            cmap = 'RdYlGn_r'
            best_idx = cells.idxmin()
            best_val = cells.min()
        else:
            cmap = 'RdYlGn'
            best_idx = cells.idxmax()
            best_val = cells.max()

        sns.heatmap(pivot, ax=ax, cmap=cmap, annot=True, fmt='.3f',
                    cbar_kws={'label': metric}, annot_kws={'fontsize': 10})

        # Outline the best cell: seaborn draws cell (row, col) on the unit square at (col, row)
        row_pos = pivot.index.get_loc(best_idx[0])
        col_pos = pivot.columns.get_loc(best_idx[1])
        ax.add_patch(plt.Rectangle((col_pos, row_pos), 1, 1,
                                   fill=False, edgecolor='black', lw=3))

        ax.set_title(f'{metric} ({direction} is better)\nOptimal: r1={best_idx[0]}, r2={best_idx[1]} = {best_val:.4f}',
                     fontsize=12, fontweight='bold')
        ax.set_xlabel('Rank 2 (Layer 2)', fontsize=11)
        ax.set_ylabel('Rank 1 (Layer 1)', fontsize=11)

    plt.suptitle('Optimal Rank Combinations (Averaged Across All KL Scales)',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/rank_optimal_summary.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {output_dir}/rank_optimal_summary.png")


def plot_rank_vs_metrics_line(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create line plots showing how each metric changes with rank values.

    One panel per metric (NLL, ECE, AUROC_OOD_MI, AUPR_OOD_MI). The x-axis is
    the total rank ``r1 + r2``. There is one error-bar line per KL scale,
    showing the mean and standard deviation of the metric at each total rank.
    The caller's DataFrame is not modified.

    Args:
        df: results with columns 'r1', 'r2', 'kl_scale_parsed' and the metric columns.
        output_dir: directory for ``rank_vs_metrics_line.png`` (created if missing).
    """
    import os

    metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']

    # Total rank capacity, computed once on a copy instead of writing a
    # 'total_rank' column into the caller's DataFrame.
    plot_df = df.assign(total_rank=df['r1'] + df['r2'])
    kl_scales = sorted(plot_df['kl_scale_parsed'].unique())

    _, axes = plt.subplots(2, 2, figsize=(14, 10))

    # flatten() is row-major, matching the original axes[idx // 2, idx % 2] order
    for ax, metric in zip(axes.flatten(), metrics):
        for kl_scale in kl_scales:
            df_kl = plot_df[plot_df['kl_scale_parsed'] == kl_scale]
            grouped = df_kl.groupby('total_rank')[metric].agg(['mean', 'std']).reset_index()

            ax.errorbar(grouped['total_rank'], grouped['mean'],
                        yerr=grouped['std'],
                        label=f'KL={kl_scale}',
                        marker='o', capsize=3, linewidth=2)

        ax.set_xlabel('Total Rank (r1 + r2)', fontsize=11)
        ax.set_ylabel(metric, fontsize=11)
        ax.set_title(f'{metric} vs Total Rank', fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/rank_vs_metrics_line.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {output_dir}/rank_vs_metrics_line.png")


def plot_pareto_front(df: pd.DataFrame, output_dir: str = 'figures'):
    """
    Create a Pareto front plot showing trade-off between calibration (NLL/ECE)
    and OOD performance (AUROC_OOD).

    There are two scatter panels: NLL vs AUROC_OOD_MI and ECE vs AUROC_OOD_MI.
    - Points are colored per KL scale and annotated with their (r1, r2) pair.
    - Dashed green lines mark the 25th percentile of the calibration metric
      and the 75th percentile of AUROC, computed over all of ``df``.

    Args:
        df: results with columns 'r1', 'r2', 'kl_scale_parsed', 'NLL', 'ECE', 'AUROC_OOD_MI'.
        output_dir: directory for ``rank_pareto_tradeoff.png`` (created if missing).
    """
    import os

    kl_scales = sorted(df['kl_scale_parsed'].unique())

    def _tradeoff_panel(ax, x_metric, x_label, title):
        # Scatter x_metric against AUROC OOD, one series per KL scale.
        for kl_scale in kl_scales:
            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            ax.scatter(df_kl[x_metric], df_kl['AUROC_OOD_MI'],
                       label=f'KL={kl_scale}', alpha=0.7, s=100)

            # Annotate points with rank values
            for _, row in df_kl.iterrows():
                ax.annotate(f'({int(row["r1"])},{int(row["r2"])})',
                            (row[x_metric], row['AUROC_OOD_MI']),
                            fontsize=7, alpha=0.7)

        # Highlight the ideal region (low calibration error, high AUROC). Drawn
        # before legend() so the reference-line label is included in the legend.
        ax.axvline(df[x_metric].quantile(0.25), color='green', linestyle='--', alpha=0.3,
                   label=f'25th percentile {x_metric}')
        ax.axhline(df['AUROC_OOD_MI'].quantile(0.75), color='green', linestyle='--', alpha=0.3)

        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel('AUROC OOD (higher is better)', fontsize=11)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: NLL vs AUROC_OOD
    _tradeoff_panel(ax1, 'NLL', 'NLL (lower is better)', 'NLL vs OOD Detection Trade-off')
    # Plot 2: ECE vs AUROC_OOD
    _tradeoff_panel(ax2, 'ECE', 'ECE (lower is better)', 'ECE vs OOD Detection Trade-off')

    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/rank_pareto_tradeoff.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {output_dir}/rank_pareto_tradeoff.png")

def _best_rank_row(frame: pd.DataFrame, metric: str) -> pd.Series:
    """
    Return the single row of ``frame`` holding the best value of ``metric``.

    NLL and ECE are minimised; every other metric is maximised. A stable
    (mergesort) sort resolves ties to the first row, as ``idxmin``/``idxmax``
    do. NaNs sort last, so they win only if the whole column is NaN.
    Positional ``iloc`` always yields exactly one row. By contrast,
    ``frame.loc[label]`` returns several rows when ``frame`` has duplicated
    index labels (e.g. after ``pd.concat`` without ``ignore_index=True``).
    """
    ascending = metric in ('NLL', 'ECE')
    ordered = frame.sort_values(metric, ascending=ascending, kind='mergesort', na_position='last')
    return ordered.iloc[0]


def find_optimal_ranks(df: pd.DataFrame) -> pd.DataFrame:
    """
    Find the optimal rank combination for each metric and KL scale.
    Returns a summary DataFrame.

    For each metric (NLL, ECE, AUROC_OOD_MI, AUPR_OOD_MI), the output contains:
    - one row for the (r1, r2) pair with the best mean over all KL scales
      (``KL_Scale == 'All (averaged)'``);
    - then one row per sorted KL scale for the best individual run at that scale.
    Lower is better for NLL/ECE; higher is better for the others.

    Args:
        df: results with columns 'r1', 'r2', 'kl_scale_parsed' and the metric columns.

    Returns:
        DataFrame with columns Metric, KL_Scale, Optimal_r1, Optimal_r2, Value.
    """
    metrics = ['NLL', 'ECE', 'AUROC_OOD_MI', 'AUPR_OOD_MI']
    kl_scales = sorted(df['kl_scale_parsed'].unique())

    def _summary(metric, kl_label, best_row):
        return {
            'Metric': metric,
            'KL_Scale': kl_label,
            'Optimal_r1': int(best_row['r1']),
            'Optimal_r2': int(best_row['r2']),
            'Value': best_row[metric],
        }

    results = []
    for metric in metrics:
        # Find optimal overall (averaged across KL scales)
        agg = df.groupby(['r1', 'r2'])[metric].mean().reset_index()
        results.append(_summary(metric, 'All (averaged)', _best_rank_row(agg, metric)))

        # Find optimal per KL scale
        for kl_scale in kl_scales:
            df_kl = df[df['kl_scale_parsed'] == kl_scale]
            results.append(_summary(metric, kl_scale, _best_rank_row(df_kl, metric)))

    return pd.DataFrame(results)





