"""
Paper figure plotting script for multimodal parametric simulation results.
This script creates only the squished_all combined plot.
"""

import os
import re
import json
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Publication-oriented matplotlib settings, applied to the global rcParams
# once when this module is imported.
plt.rcParams.update({
    # --- Typography (sizes in points) ---
    'font.family': 'serif',
    'font.size': 11,
    'axes.titlesize': 11,
    'axes.labelsize': 11,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 12,
    # --- Line weights ---
    'axes.linewidth': 0.8,
    'grid.linewidth': 0.5,
    # --- Resolution and export ---
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
})

def parse_run_name(run_name):
    """Decode the experiment parameters embedded in a run name.

    Each parameter is looked up as a ``_<key>-<value>`` token. Missing tokens
    fall back to: data_version ``'unknown'``, paired ``False``, seed ``0``,
    n_samples ``10000``, threshold_type ``'relative'``,
    compressibility_type ``'linear'``, reduction_criterion ``'r_squared'``.

    Returns:
        dict with keys ``data_version``, ``paired``, ``seed``, ``n_samples``,
        ``threshold_type``, ``compressibility_type`` and
        ``reduction_criterion``.
    """
    def _capture(pattern):
        # First capture group of the first match, or None when absent.
        found = re.search(pattern, run_name)
        return None if found is None else found.group(1)

    version_token = _capture(r'_v-([^_]+)')
    paired_token = _capture(r'_paired-([^_]+)')
    seed_token = _capture(r'_rseed-(\d+)')
    size_token = _capture(r'_n-(\d+)')
    threshold_token = _capture(r'_threshold-([^_]+)')
    compress_token = _capture(r'_compressibility-([^_]+)')
    reduction_token = _capture(r'_reduction-([^_]+)')

    # The explicit `_compressibility-` token wins; older run names only carry
    # an `R2-direct` marker somewhere in the name.
    if compress_token is not None:
        compressibility = compress_token
    elif 'R2-direct' in run_name:
        compressibility = 'direct'
    else:
        compressibility = 'linear'

    parsed = {}
    parsed['data_version'] = 'unknown' if version_token is None else version_token
    # Only the literal string 'True' counts as paired.
    parsed['paired'] = paired_token == 'True'
    parsed['seed'] = 0 if seed_token is None else int(seed_token)
    parsed['n_samples'] = 10000 if size_token is None else int(size_token)
    parsed['threshold_type'] = 'relative' if threshold_token is None else threshold_token
    parsed['compressibility_type'] = compressibility
    parsed['reduction_criterion'] = 'r_squared' if reduction_token is None else reduction_token
    return parsed

def extract_ranks_from_string(rank_string):
    """Extract individual ranks from a comma-separated string like '50, 45, 30'.

    Returns:
        A list of ints, one per comma-separated field. When the input is not
        a string (e.g. a missing value) or any field is not an integer, the
        placeholder ``[0, 0, 0]`` is returned instead.
    """
    try:
        return [int(token.strip()) for token in rank_string.split(',')]
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: input has no usable str.split (None, NaN, bytes);
        # ValueError: a field is not an integer. Anything else (including
        # KeyboardInterrupt) is deliberately allowed to propagate.
        return [0, 0, 0]

def extract_values_from_string(value_string):
    """Extract individual values from a comma-separated string like '0.83, 0.75'.

    Returns:
        A list of floats, one per comma-separated field. When the input is not
        a string (e.g. a missing value) or any field is not a number, the
        placeholder ``[0.0, 0.0]`` is returned instead.
    """
    try:
        return [float(token.strip()) for token in value_string.split(',')]
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: input has no usable str.split (None, NaN, bytes);
        # ValueError: a field is not a number. Anything else (including
        # KeyboardInterrupt) is deliberately allowed to propagate.
        return [0.0, 0.0]

def find_seed_variants(base_run_name, reports_dir="03_results/reports/"):
    """
    Find all runs that are the same as base_run_name but with different random seeds.
    Only looks for exact matches with different _rseed-X_ values.

    Every character of the run name outside the ``_rseed-<digits>_`` token is
    matched literally, so names containing regex metacharacters (for example
    the ``.`` in ``lr-0.1``) cannot match unrelated files.

    Parameters:
    - base_run_name: The base run name (can have any seed)
    - reports_dir: Directory containing the CSV files

    Returns:
    - List of all seed variants found (names without the ``.csv`` extension),
      sorted by seed number. Falls back to ``[base_run_name]`` when the name
      has no seed token, the directory is missing, or nothing matches.
    """
    seed_token = r'_rseed-\d+_'

    # Nothing to vary without a seed marker, and nothing to scan without a directory.
    if '_rseed-' not in base_run_name or not os.path.isdir(reports_dir):
        return [base_run_name]

    # Escape the literal pieces around each seed token, then re-join them with
    # the seed pattern so only the seed digits are allowed to differ.
    literal_parts = re.split(seed_token, base_run_name)
    variant_regex = re.compile(seed_token.join(re.escape(part) for part in literal_parts))

    # Sorted directory listing keeps the result deterministic for equal seeds.
    seed_variants = [
        filename[:-4]  # Remove .csv extension
        for filename in sorted(os.listdir(reports_dir))
        if filename.endswith('.csv') and variant_regex.fullmatch(filename[:-4])
    ]

    # If no matches found, return the original run name
    if not seed_variants:
        return [base_run_name]

    def _seed_of(name):
        found = re.search(r'_rseed-(\d+)_', name)
        return int(found.group(1)) if found else 0

    # Sort by seed number for consistent ordering
    seed_variants.sort(key=_seed_of)
    return seed_variants

def load_and_process_data(run_names, reports_dir="03_results/reports/"):
    """Load the report CSVs for all runs and flatten them to one row per latent space.

    Each base run name is first expanded to all of its seed variants found in
    ``reports_dir``. For every ``(r_square_threshold, config)`` group of a
    CSV, the last row of the group is used and split into one record per
    entry of its ``final_ranks`` column.

    CSV columns read: ``r_square_threshold``, ``config``, ``final_ranks``,
    ``r_squares_init``, ``classification_accuracy``, ``label_1_pred``,
    ``label_2_pred`` and, when present, ``early_stopping``,
    ``rank_reduction_frequency``, ``rank_reduction_threshold``, ``patience``.

    Parameters:
    - run_names: Iterable of base run names (any seed).
    - reports_dir: Directory containing the ``<run_name>.csv`` files; the same
      directory is used for seed-variant discovery and for loading.

    Returns:
    - pandas DataFrame with one row per (run, threshold, config, latent space);
      empty when nothing could be loaded.
    """
    # Entry i of `final_ranks` belongs to latent space i.
    # NOTE(review): more than three ranks in a row raises IndexError here.
    space_names = ['Shared', 'Specific 1', 'Specific 2']
    hyperparameter_columns = ['early_stopping', 'rank_reduction_frequency', 'rank_reduction_threshold', 'patience']

    all_data = []

    # Expand run_names to include all seed variants
    expanded_run_names = []
    for base_run_name in run_names:
        seed_variants = find_seed_variants(base_run_name, reports_dir)
        expanded_run_names.extend(seed_variants)
        print(f"Base run: {base_run_name}")
        print(f"  Found {len(seed_variants)} seed variants: {seed_variants}")

    # Remove duplicates while preserving order
    expanded_run_names = list(dict.fromkeys(expanded_run_names))

    for run_name in expanded_run_names:
        csv_file = os.path.join(reports_dir, f"{run_name}.csv")

        if not os.path.exists(csv_file):
            print(f"Warning: File {csv_file} not found, skipping...")
            continue

        print(f"Loading data from {csv_file}...")

        df = pd.read_csv(csv_file)
        run_params = parse_run_name(run_name)

        for (r_thresh, config), group in df.groupby(['r_square_threshold', 'config']):
            # The last row of the group is taken as the final state for this
            # config (relies on the CSV's row order).
            last_row = group.iloc[-1]

            ranks = extract_ranks_from_string(last_row['final_ranks'])

            # Skip runs where final ranks are [100, 100, 100] (no rank reduction occurred)
            if ranks == [100, 100, 100]:
                print(f"  Skipping config {config} with r_threshold {r_thresh} - final ranks are [100, 100, 100]")
                continue

            r_squares_init = extract_values_from_string(last_row['r_squares_init'])
            classification_accuracy = extract_values_from_string(last_row['classification_accuracy'])
            label_1_pred = extract_values_from_string(last_row['label_1_pred'])
            label_2_pred = extract_values_from_string(last_row['label_2_pred'])

            def _value_at(values, idx):
                # Per-space value, or 0.0 when the list is shorter than the rank list.
                return values[idx] if idx < len(values) else 0.0

            for space_idx, rank in enumerate(ranks):
                # Backward compatibility: the legacy `label_predictability`
                # field pairs space 1 with label 1 and space 2 with label 2.
                label_pred = 0.0
                if space_idx == 1 and len(label_1_pred) > 1:
                    label_pred = label_1_pred[1]
                elif space_idx == 2 and len(label_2_pred) > 2:
                    label_pred = label_2_pred[2]

                data_entry = {
                    'r_square_threshold': r_thresh,
                    'config': config,
                    'latent_space': space_names[space_idx],
                    'latent_space_idx': space_idx,
                    'final_rank': rank,
                    'r_squares_init': _value_at(r_squares_init, space_idx),
                    'classification_accuracy': _value_at(classification_accuracy, space_idx),
                    'label_predictability': label_pred,  # Keep for backward compatibility
                    'label_1_predictability': _value_at(label_1_pred, space_idx),  # All spaces predicting label 1
                    'label_2_predictability': _value_at(label_2_pred, space_idx),  # All spaces predicting label 2
                    'data_version': run_params['data_version'],
                    'paired': run_params['paired'],
                    'seed': run_params['seed'],
                    'n_samples': run_params['n_samples'],
                    'threshold_type': run_params['threshold_type'],
                    'compressibility_type': run_params['compressibility_type'],
                    'reduction_criterion': run_params['reduction_criterion'],
                    'run_name': run_name
                }

                # Add hyperparameters from CSV if they exist
                for col in hyperparameter_columns:
                    if col in last_row:
                        data_entry[col] = last_row[col]

                all_data.append(data_entry)

    return pd.DataFrame(all_data)

def _summarize_by_threshold(space_data, value_column):
    """Aggregate one metric per threshold: mean, standard error and sample count.

    Rows are sorted by ascending ``r_square_threshold`` so that line plots
    connect the points left to right; rows with a missing threshold are dropped.
    The standard error is 0 for thresholds with a single sample.
    """
    thresholds = space_data['r_square_threshold']
    rows = []
    for threshold in sorted(thresholds.dropna().unique()):
        values = space_data.loc[thresholds == threshold, value_column]
        count = len(values)
        std = values.std() if count > 1 else 0.0
        if pd.isna(std):
            std = 0.0
        rows.append({
            'r_square_threshold': threshold,
            'mean': values.mean(),
            'sem': std / np.sqrt(count) if count > 1 else 0.0,
            'count': count,
        })
    return pd.DataFrame(rows, columns=['r_square_threshold', 'mean', 'sem', 'count'])


def _plot_mean_with_sem(ax, summary, color, linestyle, label):
    """Draw one mean curve; SEM error bars are added only when some threshold has >1 sample."""
    line_style = dict(color=color, linestyle=linestyle, marker='o',
                      markersize=4, linewidth=1, label=label)
    if summary['count'].max() > 1:
        ax.errorbar(summary['r_square_threshold'], summary['mean'],
                    yerr=summary['sem'].to_numpy(), capsize=2, **line_style)
    else:
        ax.plot(summary['r_square_threshold'], summary['mean'], **line_style)


def _draw_metric_panel(ax, data, value_column, latent_spaces, colors, ylabel, title):
    """Plot one bottom-row panel: `value_column` vs threshold for every latent space, pooled over datasets."""
    for space_name in latent_spaces:
        space_data = data[data['latent_space'] == space_name]
        if space_data.empty:
            continue
        summary = _summarize_by_threshold(space_data, value_column)
        if summary.empty:
            continue
        _plot_mean_with_sem(ax, summary, colors[space_name], '-', space_name)

    # Reference line at the maximum attainable score.
    ax.axhline(y=1, color='grey', linestyle='--', alpha=0.8, linewidth=1)
    ax.set_xlabel('Threshold λ')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.1)
    ax.tick_params(axis='both', which='major')


def create_plot(data, save_path, ground_truth=None):
    """Create only the combined squished_all plot.

    Layout is a 2x4 grid: the top row shows final rank vs threshold for up to
    four dataset versions, the bottom row shows classification accuracy and
    label 1 / label 2 predictability (pooled over all dataset versions) plus a
    legend-only cell.

    Parameters:
    - data: DataFrame with columns ``data_version``, ``latent_space``,
      ``r_square_threshold``, ``final_rank``, ``classification_accuracy``,
      ``label_1_predictability`` and ``label_2_predictability``.
    - save_path: Base output path; ``_squished_all`` is inserted before the
      extension (``.png`` is used when there is none). The parent directory is
      created if needed.
    - ground_truth: Optional mapping ``version -> sequence of ranks`` indexed
      in the order Shared, Specific 1, Specific 2; drawn as horizontal lines.

    Returns:
    - The matplotlib Figure (already saved and closed).
    """
    # Deliberately no plt.style.use('default') here: that call resets rcParams
    # to matplotlib's defaults and would discard the publication settings
    # applied to plt.rcParams when this module is imported.

    # Fixed display order; versions outside this list are not plotted.
    desired_order = ['small', 'imbalanced-b', 'imbalanced2', 'large2']
    available_versions = set(data['data_version'].unique())
    unique_versions = [v for v in desired_order if v in available_versions][:4]

    version_display_names = {
        'small': 'Small',
        'imbalanced-b': 'Imbalanced 1',
        'imbalanced2': 'Imbalanced 2',
        'large2': 'Large'
    }

    color_palette_subspaces = {
        'Shared': '#4c72b3',      # blue
        'Specific 1': '#f48e56',  # orange
        'Specific 2': '#ea8ab2'   # pink
    }
    space_linestyles = {
        'Shared': '-',
        'Specific 1': '--',
        'Specific 2': ':'
    }
    latent_spaces = ['Shared', 'Specific 1', 'Specific 2']

    # Figure size: 24cm x 12cm = 9.45" x 4.72"
    fig_combined, axes_combined = plt.subplots(2, 4, figsize=(9.45, 4.72))

    # ===== TOP ROW: rank vs threshold, one panel per dataset version =====
    last_version_idx = len(unique_versions) - 1
    for version_idx, version in enumerate(unique_versions):
        ax = axes_combined[0, version_idx]
        version_data = data[data['data_version'] == version]

        for space_idx, space_name in enumerate(latent_spaces):
            space_data = version_data[version_data['latent_space'] == space_name]
            if space_data.empty:
                continue

            color = color_palette_subspaces[space_name]

            # Ground-truth rank; labelled only on the last panel because the
            # shared legend is built from that panel's handles.
            if ground_truth and version in ground_truth:
                ax.axhline(y=ground_truth[version][space_idx], color=color,
                           linestyle='--', alpha=0.4, linewidth=1,
                           label=f'{space_name} GT' if version_idx == last_version_idx else "")

            summary = _summarize_by_threshold(space_data, 'final_rank')
            if summary.empty:
                continue
            _plot_mean_with_sem(ax, summary, color, space_linestyles[space_name], space_name)

        ax.set_xlabel('Threshold λ')
        ax.set_ylabel('Rank')
        ax.set_title(f'{version_display_names.get(version, version)}')
        ax.set_xlim(0.004, 0.11)
        ax.set_ylim(0, 50)
        ax.tick_params(axis='both', which='major')

    # ===== BOTTOM ROW: disentanglement metrics =====
    _draw_metric_panel(axes_combined[1, 0], data, 'classification_accuracy',
                       latent_spaces, color_palette_subspaces,
                       'Classification Accuracy', 'Label 0 Predictability')
    _draw_metric_panel(axes_combined[1, 1], data, 'label_1_predictability',
                       latent_spaces, color_palette_subspaces,
                       'Goodness of Fit (R²)', 'Label 1 Predictability')
    _draw_metric_panel(axes_combined[1, 2], data, 'label_2_predictability',
                       latent_spaces, color_palette_subspaces,
                       'Goodness of Fit (R²)', 'Label 2 Predictability')

    # ===== LEGEND IN BOTTOM RIGHT (position [1,3]) =====
    ax_legend = axes_combined[1, 3]
    ax_legend.axis('off')  # legend-only cell
    # With no versions the index is -1, i.e. the (empty) last top-row axes.
    handles, labels = axes_combined[0, last_version_idx].get_legend_handles_labels()
    ax_legend.legend(handles, labels, loc='center',
                     title='Latent Subspaces', frameon=False)

    fig_combined.subplots_adjust(left=0.06, bottom=0.12, right=0.98, top=0.94, wspace=0.45, hspace=0.60)

    # Bold row markers at the top left of each row
    fig_combined.text(-0.01, 0.98, 'A', fontsize=14, fontweight='bold', ha='center', va='center')
    fig_combined.text(-0.01, 0.48, 'B', fontsize=14, fontweight='bold', ha='center', va='center')

    # Insert the suffix before the extension only, and make sure the target
    # directory exists so savefig does not fail on a fresh checkout.
    root, ext = os.path.splitext(save_path)
    combined_save_path = f"{root}_squished_all{ext or '.png'}"
    output_dir = os.path.dirname(combined_save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig_combined.savefig(combined_save_path, dpi=300, bbox_inches='tight')
    print(f"Combined squished_all plot saved to {combined_save_path}")
    plt.close(fig_combined)

    return fig_combined

def generate_summary_statistics(data, threshold_min=0.005, threshold_max=0.1):
    """Print summary statistics of ``final_rank`` within a threshold interval.

    Rows of ``data`` whose ``r_square_threshold`` lies in
    ``[threshold_min, threshold_max]`` (inclusive) are summarised overall, per
    ``data_version``, per ``latent_space``, per version/space pair, per
    threshold, and - when more than one seed is present - per ``seed``
    together with a between-/within-seed variance comparison.

    Everything is written to stdout; the function returns ``None``.
    """
    rule = '=' * 80
    dashes = "-" * 80
    stat_names = ['count', 'mean', 'std', 'min', 'max']

    def _print_grouped(title, keys):
        # One titled table of rank statistics grouped by `keys`.
        print(f"\n{title:^80}")
        print(dashes)
        print(filtered_data.groupby(keys)['final_rank'].agg(stat_names).round(2))

    print(f"\n{rule}")
    print(f"SUMMARY STATISTICS FOR RANKS (R² threshold: {threshold_min} - {threshold_max})")
    print(f"{rule}")

    in_range = (
        (data['r_square_threshold'] >= threshold_min)
        & (data['r_square_threshold'] <= threshold_max)
    )
    filtered_data = data[in_range]

    if filtered_data.empty:
        print(f"No data found in threshold range {threshold_min} - {threshold_max}")
        return

    print(f"Total data points in range: {len(filtered_data)}")
    # tolist() yields native Python numbers so the list prints as plain values.
    print(f"Threshold values: {sorted(filtered_data['r_square_threshold'].unique().tolist())}")

    # Overall statistics across all versions and spaces
    print(f"\n{'Overall Statistics (all datasets and latent spaces)':^80}")
    print(dashes)
    overall_stats = filtered_data['final_rank'].describe()
    print(f"Count: {overall_stats['count']:.0f}")
    print(f"Mean:  {overall_stats['mean']:.2f}")
    print(f"Std:   {overall_stats['std']:.2f}")
    print(f"Min:   {overall_stats['min']:.2f}")
    print(f"25%:   {overall_stats['25%']:.2f}")
    print(f"50%:   {overall_stats['50%']:.2f}")
    print(f"75%:   {overall_stats['75%']:.2f}")
    print(f"Max:   {overall_stats['max']:.2f}")

    _print_grouped('Statistics by Dataset Version', 'data_version')
    _print_grouped('Statistics by Latent Space', 'latent_space')
    _print_grouped('Detailed Statistics by Dataset Version and Latent Space',
                   ['data_version', 'latent_space'])
    _print_grouped('Statistics by R² Threshold Value', 'r_square_threshold')

    # Seed statistics (if multiple seeds available)
    if filtered_data['seed'].nunique() > 1:
        _print_grouped('Statistics by Random Seed', 'seed')

        print(f"\n{'Variance Analysis':^80}")
        print(dashes)
        total_var = filtered_data['final_rank'].var()

        # Variance of the per-seed mean ranks.
        between_seed_var = filtered_data.groupby('seed')['final_rank'].mean().var()

        # Mean of the per-seed variances (seeds with a single row are skipped).
        within_seed_vars = [
            seed_rows['final_rank'].var()
            for _, seed_rows in filtered_data.groupby('seed', sort=False)
            if len(seed_rows) > 1
        ]
        within_seed_var = np.mean(within_seed_vars) if within_seed_vars else 0

        def _share(part):
            # Percentage of the total variance; undefined when the total is
            # zero (all ranks identical) or not a number.
            if total_var > 0:
                return f"{100*part/total_var:.1f}%"
            return "n/a"

        print(f"Total variance: {total_var:.3f}")
        print(f"Between-seed variance: {between_seed_var:.3f} ({_share(between_seed_var)})")
        print(f"Within-seed variance: {within_seed_var:.3f} ({_share(within_seed_var)})")
    else:
        print("\nOnly 1 unique seed found, no seed variance analysis possible.")

    print(f"\n{rule}")

def main():
    """Command-line entry point: read the figure config, load the runs, print statistics and save the plot.

    The ``--config`` argument names a JSON file inside
    ``02_experiments/figures/`` whose ``runs`` list holds the base run names
    and whose optional ``ground_truth`` entry is forwarded to the plot. The
    output file name is derived from the config file name.
    """
    cli = argparse.ArgumentParser(description='Plot rank vs R² threshold for paper multimodal parametric simulation results')
    cli.add_argument('--config', type=str,
                     default="larrp_mm_paramsim_paper.json",
                     help='Config file name in 02_experiments/figures/ (default: larrp_mm_paramsim_paper.json)')
    options = cli.parse_args()

    config_file = f"02_experiments/figures/{options.config}"

    # Output name: drop the extension and, when present, the standard
    # 'larrp_mm_paramsim_' prefix of the config file name.
    naming_prefix = 'larrp_mm_paramsim_'
    tag = os.path.splitext(options.config)[0]
    if tag.startswith(naming_prefix):
        tag = tag[len(naming_prefix):]
    save_path = f"03_results/plots/mm_sim/rank_vs_rsquare_threshold_{tag}.png"

    if not os.path.exists(config_file):
        print(f"Error: Configuration file {config_file} not found!")
        return

    with open(config_file, 'r') as handle:
        settings = json.load(handle)

    run_names = settings.get('runs', [])
    ground_truth = settings.get('ground_truth', None)

    if not run_names:
        print("No run names found in configuration file!")
        return

    print(f"Analyzing {len(run_names)} base runs:")
    for name in run_names:
        print(f"  - {name}")

    data = load_and_process_data(run_names)

    if data.empty:
        print("No data loaded! Check if CSV files exist.")
        return

    print(f"Loaded data for {len(data)} configurations")

    # Report which seeds contributed to each dataset version.
    print("\nSeeds found for each data version:")
    for version in sorted(data['data_version'].unique()):
        seeds = sorted(data.loc[data['data_version'] == version, 'seed'].unique())
        print(f"  {version}: seeds {seeds}")

    # Rank statistics restricted to the threshold interval 0.005 - 0.1
    generate_summary_statistics(data, threshold_min=0.005, threshold_max=0.1)

    create_plot(data, save_path, ground_truth)

    # Compact per-space / per-version overview over the full data set.
    print("\nSummary statistics:")
    grouping = ['latent_space', 'data_version']
    print(data.groupby(grouping)['final_rank'].agg(['mean', 'std', 'count']))

if __name__ == "__main__":
    main()
