#!/usr/bin/env python3
"""
Uncertainty Comparison Visualization Script

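This script creates side-by-side t-SNE visualizations that compare how an
uncertainty score is distributed over the embedded samples with how the
in-distribution (ID) classes and out-of-distribution (OOD) datasets are laid out.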
"""

import matplotlib.pyplot as plt

# plt.rcParams.update({
#     # "text.usetex": True,
#     "font.family": "Helvetica",
#     # 'font.size': 14,
#     # 'legend.fontsize': 13,
#     # "text.latex.preamble": r'\usepackage{bm}',
# })

from visualize_features import FeatureVisualizer

def create_uncertainty_comparison(results_dir=None):
    """
    Create a side-by-side t-SNE uncertainty comparison visualization.

    All plotting is delegated to
    ``FeatureVisualizer.visualize_uncertainty_2d_tsne``. This function only
    gathers the configuration below, prints a short summary of it, and
    forwards it to the visualizer. Edit the configuration section to change
    the model, datasets, uncertainty measures, or t-SNE settings.

    Args:
        results_dir: Directory passed to ``FeatureVisualizer``. When ``None``
            (the default), the hard-coded experiment directory below is used,
            so existing zero-argument calls behave exactly as before.

    Returns:
        None. Any value returned by the visualizer is discarded.
    """

    # =============================================================================
    # 🔧 CONFIGURATION - MODIFY THESE PARAMETERS AS NEEDED
    # =============================================================================

    # Results directory. It can be overridden via the ``results_dir`` argument
    # so the script is not tied to a single machine's absolute path.
    if results_dir is None:
        results_dir = "/home/kemove/PythonProjects/Re-EDL/code_classical/results/2_cifar10_stat/1000-0.0_best/detailed_results"

    # Model configuration
    config_id = "cifar10-edl-exp-uce-test"
    seed = 1000

    # ID dataset and classes to visualize
    id_dataset = "CIFAR10"
    selected_classes = list(range(10))  # All ten CIFAR-10 classes (0-9)

    # OOD datasets to include
    ood_datasets = ["SVHN"]  # You can add more: ["SVHN", "CIFAR100", "GTSRB"]

    # Uncertainty measures to plot. Available: "max_prob", "max_alpha",
    # "alpha0", "differential_entropy", "mutual_information", "edl_mpu".
    uncertainty_types = [
        "edl_mpu",
        "differential_entropy",
        "mutual_information",
        "alpha0",
    ]

    # Visual parameters for 2D t-SNE
    point_size = 15        # Size of the scatter points
    alpha = 0.8            # Transparency (0.0 = transparent, 1.0 = opaque)
    figsize = (14, 10)     # Figure size; the visualizer may adjust it for the side-by-side layout (unverified here)

    # t-SNE parameters
    perplexity = 30.0      # t-SNE perplexity (typically 5-50; use smaller values for fewer samples)
    learning_rate = 200.0  # t-SNE learning rate (typically 10-1000)
    n_iter = 1000          # Number of optimization iterations
    random_state = 42      # Random seed for reproducibility

    # Performance parameters (subsampling keeps t-SNE tractable)
    max_samples_per_class = 500  # Maximum samples per ID class
    max_ood_samples = 900        # Maximum OOD samples

    # =============================================================================
    # 🎨 CREATE VISUALIZATION
    # =============================================================================

    print("🎨 Creating uncertainty comparison visualization...")
    print(f"   ID Dataset: {id_dataset}")
    print(f"   Selected Classes: {selected_classes}")
    print(f"   OOD Datasets: {ood_datasets}")
    # Previously a commented-out print of a nonexistent ``uncertainty_type``;
    # report the measures that are actually being plotted.
    print(f"   Uncertainty Types: {uncertainty_types}")
    print(f"   t-SNE Parameters: perplexity={perplexity}, lr={learning_rate}, iter={n_iter}")

    # Initialize visualizer
    visualizer = FeatureVisualizer(results_dir)

    # Create uncertainty comparison visualization
    visualizer.visualize_uncertainty_2d_tsne(
        config_id=config_id,
        seed=seed,
        id_dataset=id_dataset,
        selected_classes=selected_classes,
        ood_datasets=ood_datasets,
        uncertainty_types=uncertainty_types,
        # save_path=...,  # NOTE(review): presumably saves the figure to a file when passed — confirm in FeatureVisualizer
        figsize=figsize,
        point_size=point_size,
        alpha=alpha,
        perplexity=perplexity,
        learning_rate=learning_rate,
        n_iter=n_iter,
        random_state=random_state,
        max_samples_per_class=max_samples_per_class,
        max_ood_samples=max_ood_samples,
        show_colorbar=True,
        uncertainty_range=None,  # Auto-range, or set (min, max) for a custom range
    )

    print("✅ Uncertainty comparison visualization complete!")



if __name__ == "__main__":
    # Build and show the configured comparison figure.
    create_uncertainty_comparison()

    # Console cheat sheet shown after the run. Each section is a heading
    # followed by its entries; every string is printed verbatim, one per line.
    reference_sections = (
        (
            "\n📋 CIFAR10 Class Labels Reference:",
            (
                "   0: airplane    1: automobile  2: bird       3: cat        4: deer",
                "   5: dog         6: frog        7: horse      8: ship       9: truck",
            ),
        ),
        (
            "\n📋 Available OOD Datasets:",
            (
                "   - SVHN: Street View House Numbers",
                "   - CIFAR100: CIFAR-100 dataset",
                "   - GTSRB: German Traffic Sign Recognition Benchmark",
                "   - Places365: Places365 dataset",
                "   - Food101: Food-101 dataset",
            ),
        ),
        (
            "\n📋 Available Uncertainty Types:",
            (
                "   - max_prob: Maximum predicted probability",
                "   - max_alpha: Maximum alpha value",
                "   - alpha0: Sum of alpha values (precision)",
                "   - differential_entropy: Differential entropy",
                "   - mutual_information: Mutual information (epistemic uncertainty)",
                "   - edl_mpu: The proposed margin-aware predictive uncertainty",
            ),
        ),
        (
            "\n💡 Visualization Features:",
            (
                "   - Left plot: Uncertainty values (viridis colormap with correct orientation)",
                "   - Right plot: ID classes (colored) + OOD datasets (different markers)",
                "   - Side-by-side comparison for easy analysis",
                "   - Automatic t-SNE parameter compatibility",
                "   - Performance optimization for large datasets",
                "   - Colorbar: Small values at bottom, large values at top",
            ),
        ),
    )

    for heading, entries in reference_sections:
        print(heading)
        for entry in entries:
            print(entry)