import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.metrics.pairwise import euclidean_distances
from scipy.stats import ks_2samp, wasserstein_distance
from scipy.spatial.distance import cdist
import os
import glob
import argparse

# Set up argument parser
parser = argparse.ArgumentParser(description='Analyze cross-modal VAE generations')
parser.add_argument('--run_name', type=str, required=True, help='Name of the run to analyze')
parser.add_argument('--config', type=int, default=None, help='Specific config number to analyze (default: all)')
parser.add_argument('--n_clusters', type=int, default=5, help='Number of clusters expected in the data')
parser.add_argument('--save_plots', action='store_true', help='Save individual plots for each config')
args = parser.parse_args()

# silhouette_score (computed for every config) is only defined for 2 <= n_labels <= n_samples - 1,
# and KMeans rejects non-positive cluster counts. Fail fast with a usage message instead of a
# traceback after the first file has already been loaded.
if args.n_clusters < 2:
    parser.error(f"--n_clusters must be at least 2 (got {args.n_clusters})")

# Find all cross-modal generation files for this run.
data_dir = "03_results/cross_modal_generations/"
# Escape the run name so glob metacharacters in it ('[', '*', '?') are matched literally.
run_glob = glob.escape(args.run_name)
if args.config is not None:
    pattern = f"{data_dir}{run_glob}_config-{args.config}.npz"
else:
    pattern = f"{data_dir}{run_glob}_config-*.npz"


def _config_id(path):
    """Return N parsed from '<run_name>_config-<N>.npz', or None when N is not an integer.

    rsplit takes the LAST 'config-' so run names that themselves contain 'config-' still parse.
    """
    suffix = os.path.basename(path)[:-len(".npz")].rsplit("config-", 1)[-1]
    try:
        return int(suffix)
    except ValueError:
        return None


# Keep only files whose config suffix is an integer (the analysis loop converts it with int()),
# and order them numerically so that config-10 comes after config-2.
files = sorted((p for p in glob.glob(pattern) if _config_id(p) is not None), key=_config_id)
print(f"Found {len(files)} files matching pattern: {pattern}")

if not files:
    print("No files found. Exiting.")
    # `exit` is injected by the `site` module and is not guaranteed to exist (e.g. `python -S`).
    raise SystemExit(0)

# Accumulator for per-config result dicts; they become the rows of the output CSV.
all_results = []

# Run-specific plot directory (created if missing, reused if it already exists).
plot_dir = f"03_results/plots/cross_modal_analysis/{args.run_name}/"
os.makedirs(plot_dir, exist_ok=True)

for file_idx, file_path in enumerate(files):
    print(f"\n{'='*80}")
    print(f"Analyzing file {file_idx+1}/{len(files)}: {os.path.basename(file_path)}")
    print(f"{'='*80}")
    
    # Extract config number from "<run_name>_config-<N>.npz". rsplit takes the LAST
    # "config-" so run names that themselves contain "config-" still parse correctly.
    config_num = int(os.path.basename(file_path)[:-len('.npz')].rsplit('config-', 1)[1])
    
    # Load data. An NpzFile keeps the archive open and re-reads an array from the zip on
    # every key access (arrays below are accessed repeatedly), so copy everything into a
    # plain dict once and close the file.
    with np.load(file_path) as npz:
        data = {key: npz[key] for key in npz.files}
    
    # Modalities are counted from the 'original_mod*' keys; the code below assumes their
    # indices are contiguous 0..n_modalities-1 (NOTE(review): confirm with the writer).
    n_modalities = sum(1 for key in data if key.startswith('original_mod'))
    
    print(f"\nFound {n_modalities} modalities")
    print(f"Available keys: {list(data.keys())}")
    
    # Initialize result dictionary for this config
    result = {
        'config': config_num,
        'file': os.path.basename(file_path)
    }
    
    # Get shared latents (these should be the same regardless of modality)
    shared_mu = data['shared_mu']
    n_samples = shared_mu.shape[0]
    print(f"Number of validation samples: {n_samples}")
    print(f"Shared latent dimension: {shared_mu.shape[1]}")
    
    # Cluster the shared latents. These KMeans labels are the reference partition
    # ("pseudo ground truth") that every other clustering below is compared against.
    kmeans_shared = KMeans(n_clusters=args.n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans_shared.fit_predict(shared_mu)
    
    result['n_samples'] = n_samples
    result['shared_dim'] = shared_mu.shape[1]
    result['silhouette_shared'] = silhouette_score(shared_mu, cluster_labels)
    
    print(f"\nClustering on shared latents:")
    print(f"  Silhouette score: {result['silhouette_shared']:.4f}")
    print(f"  Cluster distribution: {np.bincount(cluster_labels)}")
    
    # ============================================================================
    # Analysis 1: Cluster Assignment Consistency
    # ============================================================================
    print(f"\n{'='*80}")
    print("ANALYSIS 1: Cluster Assignment Consistency")
    print(f"{'='*80}")
    
    for mod in range(n_modalities):
        print(f"\n--- Modality {mod} ---")
        
        recon = data[f'recon_mod{mod}']
        
        # Cluster reconstructions; the fitted centroids are reused below to assign
        # cross-modal generations into the same cluster space.
        kmeans_recon = KMeans(n_clusters=args.n_clusters, random_state=42, n_init=10)
        recon_clusters = kmeans_recon.fit_predict(recon)
        
        # Compute agreement with shared latent clusters
        ari_recon = adjusted_rand_score(cluster_labels, recon_clusters)
        nmi_recon = normalized_mutual_info_score(cluster_labels, recon_clusters)
        
        result[f'mod{mod}_recon_ari'] = ari_recon
        result[f'mod{mod}_recon_nmi'] = nmi_recon
        
        print(f"Reconstruction cluster agreement with shared latents:")
        print(f"  ARI: {ari_recon:.4f}")
        print(f"  NMI: {nmi_recon:.4f}")
        
        # Check cross-modal generations
        for source_mod in range(n_modalities):
            if source_mod == mod:
                continue
            
            cross_key = f'cross_mod{source_mod}_to_mod{mod}'
            if cross_key not in data:
                print(f"  Warning: {cross_key} not found in data")
                continue
                
            cross_gen = data[cross_key]
            
            # Assign cross-modal generations to the reconstruction clusters
            cross_clusters = kmeans_recon.predict(cross_gen)
            
            # Agreement with the shared-latent clustering
            ari_cross = adjusted_rand_score(cluster_labels, cross_clusters)
            nmi_cross = normalized_mutual_info_score(cluster_labels, cross_clusters)
            
            # Agreement with reconstruction clusters
            ari_cross_recon = adjusted_rand_score(recon_clusters, cross_clusters)
            nmi_cross_recon = normalized_mutual_info_score(recon_clusters, cross_clusters)
            
            result[f'cross_mod{source_mod}_to_mod{mod}_ari_shared'] = ari_cross
            result[f'cross_mod{source_mod}_to_mod{mod}_nmi_shared'] = nmi_cross
            result[f'cross_mod{source_mod}_to_mod{mod}_ari_recon'] = ari_cross_recon
            result[f'cross_mod{source_mod}_to_mod{mod}_nmi_recon'] = nmi_cross_recon
            
            print(f"\nCross-modal (mod{source_mod}→mod{mod}) cluster agreement:")
            print(f"  With shared latents - ARI: {ari_cross:.4f}, NMI: {nmi_cross:.4f}")
            print(f"  With reconstructions - ARI: {ari_cross_recon:.4f}, NMI: {nmi_cross_recon:.4f}")
    
    # ============================================================================
    # Analysis 2: Distribution Comparison
    # ============================================================================
    print(f"\n{'='*80}")
    print("ANALYSIS 2: Distribution Comparison")
    print(f"{'='*80}")
    
    for mod in range(n_modalities):
        print(f"\n--- Modality {mod} ---")
        
        source_mods = [
            source_mod for source_mod in range(n_modalities)
            if source_mod != mod and f'cross_mod{source_mod}_to_mod{mod}' in data
        ]
        if not source_mods:
            continue
        
        recon = data[f'recon_mod{mod}']
        
        # Reconstruction statistics and the reference PCA depend only on the target
        # modality, so compute them once here rather than once per source modality.
        recon_mean = recon.mean(axis=0)
        recon_std = recon.std(axis=0)
        # n_components may not exceed min(n_samples, n_features); random_state pins the
        # randomized SVD solver sklearn may pick for large inputs.
        pca = PCA(n_components=min(10, recon.shape[0], recon.shape[1]), random_state=42)
        recon_pca = pca.fit_transform(recon)
        
        for source_mod in source_mods:
            cross_gen = data[f'cross_mod{source_mod}_to_mod{mod}']
            
            # 1. Mean and std comparison (per feature, then averaged over features)
            cross_mean = cross_gen.mean(axis=0)
            cross_std = cross_gen.std(axis=0)
            
            mean_diff = np.abs(recon_mean - cross_mean).mean()
            std_diff = np.abs(recon_std - cross_std).mean()
            
            result[f'cross_mod{source_mod}_to_mod{mod}_mean_diff'] = mean_diff
            result[f'cross_mod{source_mod}_to_mod{mod}_std_diff'] = std_diff
            
            # 2. Kolmogorov-Smirnov test per feature (first 100 features only), averaged.
            # NOTE: the mean of per-feature p-values is a coarse summary, not a valid
            # combined test.
            ks_stats = []
            ks_pvals = []
            for feat_idx in range(min(recon.shape[1], 100)):
                ks_stat, ks_pval = ks_2samp(recon[:, feat_idx], cross_gen[:, feat_idx])
                ks_stats.append(ks_stat)
                ks_pvals.append(ks_pval)
            
            avg_ks_stat = np.mean(ks_stats)
            avg_ks_pval = np.mean(ks_pvals)
            
            result[f'cross_mod{source_mod}_to_mod{mod}_ks_stat'] = avg_ks_stat
            result[f'cross_mod{source_mod}_to_mod{mod}_ks_pval'] = avg_ks_pval
            
            # 3. 1-D Wasserstein distance per PCA component (PCA fitted on reconstructions)
            cross_pca = pca.transform(cross_gen)
            avg_wasserstein = np.mean([
                wasserstein_distance(recon_pca[:, dim], cross_pca[:, dim])
                for dim in range(recon_pca.shape[1])
            ])
            result[f'cross_mod{source_mod}_to_mod{mod}_wasserstein'] = avg_wasserstein
            
            # 4. Nearest neighbor analysis: mean distance from each cross-generated sample
            # to its closest reconstruction (builds the full pairwise distance matrix).
            distances = cdist(cross_gen, recon, metric='euclidean')
            avg_min_distance = distances.min(axis=1).mean()
            
            result[f'cross_mod{source_mod}_to_mod{mod}_avg_nn_dist'] = avg_min_distance
            
            print(f"\nCross-modal (mod{source_mod}→mod{mod}) distribution comparison:")
            print(f"  Mean difference: {mean_diff:.6f}")
            print(f"  Std difference: {std_diff:.6f}")
            print(f"  KS statistic (avg): {avg_ks_stat:.4f}, p-value: {avg_ks_pval:.4f}")
            print(f"  Wasserstein distance (PCA): {avg_wasserstein:.4f}")
            print(f"  Avg nearest neighbor distance: {avg_min_distance:.4f}")
    
    # ============================================================================
    # Visualization (if requested)
    # ============================================================================
    if args.save_plots:
        print(f"\nGenerating visualizations...")
        
        # Collect every (source, target) pair for which a cross-modal generation exists
        cross_pairs = []
        for source_mod in range(n_modalities):
            for target_mod in range(n_modalities):
                if source_mod != target_mod:
                    cross_key = f'cross_mod{source_mod}_to_mod{target_mod}'
                    if cross_key in data:
                        cross_pairs.append((source_mod, target_mod, cross_key))
        
        n_pairs = len(cross_pairs)
        if n_pairs == 0:
            # With no pairs n_cols would be 0 and the row computation would divide by zero.
            print("No cross-modal generations found; skipping cross-modal comparison plot.")
        else:
            n_cols = min(3, n_pairs)
            n_rows = (n_pairs + n_cols - 1) // n_cols + 1  # +1 for shared space row
            
            fig = plt.figure(figsize=(7 * n_cols, 6 * n_rows))
            
            # 1. Shared latent space (PCA) in the first cell of the top row; the other
            #    top-row cells stay empty so the pair plots start on the second row.
            ax_shared = plt.subplot(n_rows, n_cols, 1)
            if shared_mu.shape[1] > 2:
                pca_shared = PCA(n_components=2, random_state=42)
                shared_2d = pca_shared.fit_transform(shared_mu)
            else:
                shared_2d = shared_mu[:, :2]
            
            scatter = ax_shared.scatter(shared_2d[:, 0], shared_2d[:, 1], c=cluster_labels, 
                                cmap='tab10', alpha=0.7, s=30, edgecolors='black', linewidths=0.5)
            ax_shared.set_title('Shared Latent Space', fontsize=14, fontweight='bold')
            ax_shared.set_xlabel('PC1', fontsize=12)
            ax_shared.set_ylabel('PC2', fontsize=12)
            cbar = plt.colorbar(scatter, ax=ax_shared)
            cbar.set_label('Cluster Label', fontsize=11)
            ax_shared.grid(alpha=0.3)
            
            # 2-N. Cross-modal generation comparisons.
            # Each subplot shows reconstruction (circles) vs cross-generation (triangles),
            # both colored by the shared-latent KMeans labels.
            plot_idx = n_cols + 1
            
            for source_mod, target_mod, cross_key in cross_pairs:
                ax = plt.subplot(n_rows, n_cols, plot_idx)
                
                recon = data[f'recon_mod{target_mod}']
                cross_gen = data[cross_key]
                
                # Fit PCA on combined data so both point sets share one projection
                combined = np.vstack([recon, cross_gen])
                pca_mod = PCA(n_components=2, random_state=42)
                combined_2d = pca_mod.fit_transform(combined)
                
                # Local name chosen so the per-config n_samples above is not overwritten
                n_recon = len(recon)
                recon_2d = combined_2d[:n_recon]
                cross_2d = combined_2d[n_recon:]
                
                # Plot reconstructions (circles)
                ax.scatter(recon_2d[:, 0], recon_2d[:, 1], 
                           c=cluster_labels, cmap='tab10', 
                           alpha=0.5, s=40, marker='o', 
                           edgecolors='black', linewidths=0.5,
                           label='Reconstruction')
                
                # Plot cross-generations (triangles)
                ax.scatter(cross_2d[:, 0], cross_2d[:, 1], 
                           c=cluster_labels, cmap='tab10', 
                           alpha=0.7, s=60, marker='^', 
                           edgecolors='black', linewidths=0.5,
                           label='Cross-Generated')
                
                ax.set_title(f'Mod{source_mod} → Mod{target_mod}', 
                            fontsize=13, fontweight='bold')
                ax.set_xlabel('PC1', fontsize=11)
                ax.set_ylabel('PC2', fontsize=11)
                ax.legend(loc='best', fontsize=10, framealpha=0.9)
                ax.grid(alpha=0.3)
                
                # Annotate with agreement vs. the shared-latent clustering. These are the
                # keys Analysis 1 actually writes (the old '_ari'/'_nmi' keys never existed,
                # so the annotation was silently never drawn).
                ari_key = f'cross_mod{source_mod}_to_mod{target_mod}_ari_shared'
                nmi_key = f'cross_mod{source_mod}_to_mod{target_mod}_nmi_shared'
                if ari_key in result and nmi_key in result:
                    text = f'ARI: {result[ari_key]:.3f}\nNMI: {result[nmi_key]:.3f}'
                    ax.text(0.02, 0.98, text, transform=ax.transAxes,
                           fontsize=10, verticalalignment='top',
                           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))
                
                plot_idx += 1
            
            plt.suptitle(f'Cross-Modal Generation Quality - Config {config_num}', 
                        fontsize=16, fontweight='bold', y=0.998)
            plt.tight_layout()
            
            plot_file = f"{plot_dir}/config_{config_num}_cross_modal_comparison.png"
            plt.savefig(plot_file, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"Saved cross-modal comparison plot to {plot_file}")
        
        # Detailed plot per target modality showing all cross-generations together
        markers = ['^', 's', 'D', 'v', '<', '>', 'p', '*']
        for target_mod in range(n_modalities):
            fig, axes = plt.subplots(1, 2, figsize=(16, 7))
            
            recon = data[f'recon_mod{target_mod}']
            pca_mod = PCA(n_components=2, random_state=42)
            recon_2d = pca_mod.fit_transform(recon)
            
            # Left plot: all cross-generations overlaid on the reconstructions' PCA space
            ax = axes[0]
            ax.scatter(recon_2d[:, 0], recon_2d[:, 1], 
                      c='lightgray', alpha=0.3, s=30, marker='o',
                      label='Reconstruction (background)')
            
            for source_mod in range(n_modalities):
                if source_mod == target_mod:
                    continue
                cross_key = f'cross_mod{source_mod}_to_mod{target_mod}'
                if cross_key in data:
                    cross_2d = pca_mod.transform(data[cross_key])
                    ax.scatter(cross_2d[:, 0], cross_2d[:, 1], 
                             c=cluster_labels, cmap='tab10',
                             alpha=0.7, s=70, marker=markers[source_mod % len(markers)],
                             edgecolors='black', linewidths=0.5,
                             label=f'From Mod{source_mod}')
            
            ax.set_title(f'All Cross-Generations to Mod{target_mod}', 
                        fontsize=13, fontweight='bold')
            ax.set_xlabel('PC1', fontsize=11)
            ax.set_ylabel('PC2', fontsize=11)
            ax.legend(loc='best', fontsize=10, framealpha=0.9)
            ax.grid(alpha=0.3)
            
            # Right plot: reconstructions colored by the shared-latent clusters
            ax = axes[1]
            scatter = ax.scatter(recon_2d[:, 0], recon_2d[:, 1], 
                               c=cluster_labels, cmap='tab10',
                               alpha=0.7, s=50, marker='o',
                               edgecolors='black', linewidths=0.5)
            ax.set_title(f'Mod{target_mod} Reconstructions (by Cluster)', 
                        fontsize=13, fontweight='bold')
            ax.set_xlabel('PC1', fontsize=11)
            ax.set_ylabel('PC2', fontsize=11)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Cluster Label', fontsize=11)
            ax.grid(alpha=0.3)
            
            plt.suptitle(f'Detailed View: Target Modality {target_mod} - Config {config_num}', 
                        fontsize=15, fontweight='bold')
            plt.tight_layout()
            
            detail_file = f"{plot_dir}/config_{config_num}_mod{target_mod}_detail.png"
            plt.savefig(detail_file, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"Saved detailed plot to {detail_file}")
    
    all_results.append(result)

# ============================================================================
# Save aggregate results
# ============================================================================
banner = '=' * 80
print(f"\n{banner}\nSaving aggregate results...\n{banner}")

# Make sure the destination folder exists before writing the CSV.
output_csv = f"03_results/processed/{args.run_name}_cross_modal_analysis.csv"
os.makedirs(os.path.dirname(output_csv), exist_ok=True)

# One row per analyzed config; columns are the union of all metric keys.
results_df = pd.DataFrame(all_results)
results_df.to_csv(output_csv, index=False)
print(f"Saved results to {output_csv}")

# Print summary statistics: mean ± std across configs for each metric column.
rule = '=' * 80
print(f"\n{rule}\nSUMMARY STATISTICS\n{rule}")

cluster_cols = [col for col in results_df.columns if any(tag in col for tag in ('_ari', '_nmi'))]
dist_cols = [
    col for col in results_df.columns
    if any(tag in col for tag in ('_mean_diff', '_std_diff', '_wasserstein', '_nn_dist'))
]

# (heading, columns, decimal places) — sections with no matching columns are skipped.
for heading, section_cols, precision in (
    ("Cluster Consistency Metrics (mean ± std):", cluster_cols, 4),
    ("Distribution Similarity Metrics (mean ± std):", dist_cols, 6),
):
    if not section_cols:
        continue
    print(f"\n{heading}")
    for col in section_cols:
        series = results_df[col]
        print(f"  {col}: {series.mean():.{precision}f} ± {series.std():.{precision}f}")

print(f"\n{rule}\nAnalysis complete!")
print(rule)

# ============================================================================
# Create summary plots across all configs
# ============================================================================
print(f"\nGenerating summary plots...")

# Create a comprehensive summary figure (2x3 grid of panels)
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. Cluster consistency metrics (ARI and NMI)
ax = axes[0, 0]
ari_cols = [col for col in results_df.columns if '_ari' in col]
nmi_cols = [col for col in results_df.columns if '_nmi' in col]

if ari_cols or nmi_cols:
    x_pos = []
    labels = []
    means = []
    stds = []
    colors = []

    # Bars sit 2 units apart. A non-empty ARI group is followed by one extra unit of space
    # before the NMI group. (The old NMI offset counted the ARI bars twice, leaving a wide
    # empty gap.) Tick labels are prefixed because the two groups share column stems.
    next_x = 0
    for group_cols, suffix, prefix, color in (
        (ari_cols, '_ari', 'ARI: ', 'steelblue'),
        (nmi_cols, '_nmi', 'NMI: ', 'coral'),
    ):
        if means:
            next_x += 1
        for col in group_cols:
            mean_val = results_df[col].mean()
            std_val = results_df[col].std()
            # Only add if we have valid data
            if np.isnan(mean_val):
                continue
            x_pos.append(next_x)
            labels.append(prefix + col.replace(suffix, '').replace('_', ' ').title())
            means.append(mean_val)
            # std is NaN with a single config; draw a zero-length error bar instead
            stds.append(std_val if not np.isnan(std_val) else 0)
            colors.append(color)
            next_x += 2

    if means:  # Only plot if we have data
        ax.bar(x_pos, means, yerr=stds, capsize=5, color=colors, alpha=0.7)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
        ax.set_ylabel('Score')
        ax.set_title('Cluster Consistency Metrics (ARI & NMI)')
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.5)
        ax.grid(axis='y', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No valid cluster metrics', ha='center', va='center')
        ax.set_title('Cluster Consistency Metrics')
else:
    ax.text(0.5, 0.5, 'No cluster metrics available', ha='center', va='center')
    ax.set_title('Cluster Consistency Metrics')

# 2. Distribution similarity - Wasserstein distances
ax = axes[0, 1]
wasserstein_cols = [col for col in results_df.columns if 'wasserstein' in col]

if not wasserstein_cols:
    ax.text(0.5, 0.5, 'No Wasserstein metrics available', ha='center', va='center')
    ax.set_title('Distribution Similarity')
else:
    # (label, mean, std) for every column whose mean is defined; an undefined std
    # (e.g. only one config) becomes a zero-length error bar.
    bar_rows = []
    for metric_col in wasserstein_cols:
        col_mean = results_df[metric_col].mean()
        if np.isnan(col_mean):
            continue
        col_std = results_df[metric_col].std()
        bar_rows.append((
            metric_col.replace('_wasserstein', '').replace('_', ' ').title(),
            col_mean,
            0 if np.isnan(col_std) else col_std,
        ))

    if bar_rows:
        bar_labels, bar_heights, bar_errors = map(list, zip(*bar_rows))
        bar_x = list(range(len(bar_rows)))
        ax.bar(bar_x, bar_heights, yerr=bar_errors, capsize=5, color='seagreen', alpha=0.7)
        ax.set_xticks(bar_x)
        ax.set_xticklabels(bar_labels, rotation=45, ha='right', fontsize=9)
        ax.set_ylabel('Wasserstein Distance')
        ax.set_title('Distribution Similarity (Wasserstein)')
        ax.grid(axis='y', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No valid Wasserstein metrics', ha='center', va='center')
        ax.set_title('Distribution Similarity')

# 3. Mean and std differences
ax = axes[0, 2]
mean_diff_cols = [col for col in results_df.columns if 'mean_diff' in col]
std_diff_cols = [col for col in results_df.columns if 'std_diff' in col]

if mean_diff_cols or std_diff_cols:
    x_pos = []
    labels = []
    means = []
    stds = []
    colors = []

    # Bars sit 2 units apart. A non-empty "Mean" group is followed by one extra unit of
    # space before the "Std" group. (The old offset added the Mean-bar count twice.)
    next_x = 0
    for group_cols, suffix, prefix, color in (
        (mean_diff_cols, '_mean_diff', 'Mean: ', 'mediumpurple'),
        (std_diff_cols, '_std_diff', 'Std: ', 'orange'),
    ):
        if means:
            next_x += 1
        for col in group_cols:
            mean_val = results_df[col].mean()
            std_val = results_df[col].std()
            if np.isnan(mean_val):
                continue
            x_pos.append(next_x)
            labels.append(prefix + col.replace(suffix, '').replace('_', ' '))
            means.append(mean_val)
            # std is NaN with a single config; draw a zero-length error bar instead
            stds.append(std_val if not np.isnan(std_val) else 0)
            colors.append(color)
            next_x += 2

    if means:
        ax.bar(x_pos, means, yerr=stds, capsize=5, color=colors, alpha=0.7)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
        ax.set_ylabel('Absolute Difference')
        ax.set_title('Mean & Std Differences')
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.5)
        ax.grid(axis='y', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No valid mean/std metrics', ha='center', va='center')
        ax.set_title('Mean & Std Differences')
else:
    ax.text(0.5, 0.5, 'No mean/std metrics available', ha='center', va='center')
    ax.set_title('Mean & Std Differences')

# 4. Nearest neighbor distances
ax = axes[1, 0]
nn_cols = [col for col in results_df.columns if 'nn_dist' in col]

if not nn_cols:
    ax.text(0.5, 0.5, 'No NN distance metrics available', ha='center', va='center')
    ax.set_title('Nearest Neighbor Distances')
else:
    # (label, mean, std) per column with a defined mean; NaN std -> zero error bar.
    nn_rows = []
    for metric_col in nn_cols:
        col_mean = results_df[metric_col].mean()
        if np.isnan(col_mean):
            continue
        col_std = results_df[metric_col].std()
        nn_rows.append((
            metric_col.replace('_nn_dist', '').replace('_', ' ').title(),
            col_mean,
            0 if np.isnan(col_std) else col_std,
        ))

    if nn_rows:
        nn_labels, nn_heights, nn_errors = map(list, zip(*nn_rows))
        nn_x = list(range(len(nn_rows)))
        ax.bar(nn_x, nn_heights, yerr=nn_errors, capsize=5, color='crimson', alpha=0.7)
        ax.set_xticks(nn_x)
        ax.set_xticklabels(nn_labels, rotation=45, ha='right', fontsize=9)
        ax.set_ylabel('Average NN Distance')
        ax.set_title('Nearest Neighbor Distances')
        ax.grid(axis='y', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No valid NN distance metrics', ha='center', va='center')
        ax.set_title('Nearest Neighbor Distances')

# 5. Silhouette scores if available
ax = axes[1, 1]
sil_cols = [col for col in results_df.columns if 'silhouette' in col]

if not sil_cols:
    ax.text(0.5, 0.5, 'No silhouette metrics available', ha='center', va='center')
    ax.set_title('Clustering Quality')
else:
    # (label, mean, std) per column with a defined mean; NaN std -> zero error bar.
    sil_rows = []
    for metric_col in sil_cols:
        col_mean = results_df[metric_col].mean()
        if np.isnan(col_mean):
            continue
        col_std = results_df[metric_col].std()
        sil_rows.append((
            metric_col.replace('_silhouette', '').replace('_', ' ').title(),
            col_mean,
            0 if np.isnan(col_std) else col_std,
        ))

    if sil_rows:
        sil_labels, sil_heights, sil_errors = map(list, zip(*sil_rows))
        sil_x = list(range(len(sil_rows)))
        ax.bar(sil_x, sil_heights, yerr=sil_errors, capsize=5, color='teal', alpha=0.7)
        ax.set_xticks(sil_x)
        ax.set_xticklabels(sil_labels, rotation=45, ha='right', fontsize=9)
        ax.set_ylabel('Silhouette Score')
        ax.set_title('Clustering Quality (Silhouette)')
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.5)
        ax.grid(axis='y', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No valid silhouette metrics', ha='center', va='center')
        ax.set_title('Clustering Quality')

# 6. Summary statistics table (text-only panel)
ax = axes[1, 2]
ax.axis('off')

text_rule = '=' * 40
text_lines = [
    "Summary Statistics",
    text_rule,
    "",
    f"Number of configs analyzed: {len(results_df)}",
    "",
]

# Grand mean over configs and columns for each metric family that produced columns.
for metric_label, metric_cols in (
    ("Avg ARI", ari_cols),
    ("Avg NMI", nmi_cols),
    ("Avg Wasserstein", wasserstein_cols),
    ("Avg NN Dist", nn_cols),
):
    if metric_cols:
        text_lines.append(f"{metric_label}: {results_df[metric_cols].mean().mean():.4f}")

text_lines += [
    "",
    text_rule,
    "Interpretation:",
    "• ARI/NMI close to 1: Good cluster preservation",
    "• Low Wasserstein: Similar distributions",
    "• Low NN distance: Similar local structure",
]
summary_text = "\n".join(text_lines) + "\n"

ax.text(0.1, 0.9, summary_text, fontsize=10, verticalalignment='top',
        family='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

# Title, lay out, and write the summary figure one level above the per-run plot directory.
fig.suptitle(f'Cross-Modal Analysis Summary: {args.run_name}', fontsize=16, y=0.995)
fig.tight_layout()

summary_plot_file = f"{plot_dir}/../{args.run_name}_summary.png"
fig.savefig(summary_plot_file, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f"Saved summary plot to {summary_plot_file}")

closing_rule = '=' * 80
print(f"\n{closing_rule}\nAll plots generated successfully!\n{closing_rule}")
