import pandas as pd
import numpy as np
import glob
import os
from scipy import stats

# Column order of the per-parameter-value statistics tables.
_PER_VALUE_COLUMNS = ['param_name', 'param_value', 'mean_rank', 'sem_rank',
                      'mean_deviation', 'sem_deviation', 'n_experiments']

# Column order of the parameter-importance tables.
_IMPORTANCE_COLUMNS = ['param_name', 'deviation_variance', 'deviation_range',
                       'min_deviation', 'max_deviation']

# Columns that jointly identify one parameter set within a threshold type.
_COMBO_KEYS = ['r_square_threshold', 'rank_reduction_frequency',
               'rank_reduction_threshold', 'patience']


def _banner(title):
    """Print *title* framed by two 50-character rules, preceded by a blank line."""
    rule = "=" * 50
    print("\n" + rule)
    print(title)
    print(rule)


def _group_summary(group):
    """Return mean/SEM of rank and deviation, rank std and row count for *group*."""
    ranks = group['final_rank']
    deviations = group['rank_deviation']
    return {
        'mean_rank': ranks.mean(),
        'sem_rank': ranks.sem(),
        'mean_deviation': deviations.mean(),
        'sem_deviation': deviations.sem(),
        'n_experiments': len(group),
        'std_rank': ranks.std(),
    }


def _print_rank_overview(frame, target_rank=None):
    """Print the mean/deviation/range/std lines for *frame* (target line is optional)."""
    ranks = frame['final_rank']
    deviations = frame['rank_deviation']
    print(f"Mean final rank: {ranks.mean():.2f} ± {ranks.sem():.2f}")
    if target_rank is not None:
        print(f"Target rank: {target_rank}")
    print(f"Mean absolute deviation: {deviations.mean():.2f} ± {deviations.sem():.2f}")
    print(f"Rank range: {ranks.min():.1f} - {ranks.max():.1f}")
    print(f"Rank std: {ranks.std():.2f}")


def _rank_summary_table(frame, prefix):
    """Build the seven-row metric/value table for *frame*, metric names prefixed by *prefix*."""
    ranks = frame['final_rank']
    deviations = frame['rank_deviation']
    names = ['mean_rank', 'sem_rank', 'mean_deviation', 'sem_deviation',
             'min_rank', 'max_rank', 'std_rank']
    return pd.DataFrame({
        'metric': [f"{prefix}_{name}" for name in names],
        'value': [ranks.mean(), ranks.sem(), deviations.mean(), deviations.sem(),
                  ranks.min(), ranks.max(), ranks.std()],
    })


def _per_value_rows(frame, param, indent=""):
    """Print and return one statistics row per distinct value of *param* in *frame*."""
    rows = []
    for param_value, group in frame.groupby(param):
        summary = _group_summary(group)
        rows.append({
            'param_name': param,
            'param_value': param_value,
            'mean_rank': summary['mean_rank'],
            'sem_rank': summary['sem_rank'],
            'mean_deviation': summary['mean_deviation'],
            'sem_deviation': summary['sem_deviation'],
            'n_experiments': summary['n_experiments'],
        })
        print(f"{indent}{param_value}: {summary['mean_rank']:.2f} ± {summary['sem_rank']:.2f} "
              f"(dev: {summary['mean_deviation']:.2f} ± {summary['sem_deviation']:.2f}, "
              f"n={summary['n_experiments']})")
    return rows


def _importance_table(frame, params):
    """Print and return, per parameter, the spread of mean deviation across its values.

    The returned table is sorted by ``deviation_variance`` in descending order.
    A parameter with no groups (e.g. *frame* is empty) gets NaN entries instead
    of raising from ``np.max`` on an empty sequence.
    """
    rows = []
    for param in params:
        deviations = [group['rank_deviation'].mean() for _, group in frame.groupby(param)]
        if deviations:
            variance = np.var(deviations)
            lowest = np.min(deviations)
            highest = np.max(deviations)
        else:
            variance = lowest = highest = np.nan
        spread = highest - lowest
        rows.append({
            'param_name': param,
            'deviation_variance': variance,
            'deviation_range': spread,
            'min_deviation': lowest,
            'max_deviation': highest,
        })
        print(f"{param}: variance={variance:.3f}, range={spread:.3f}")

    table = pd.DataFrame(rows, columns=_IMPORTANCE_COLUMNS)
    return table.sort_values('deviation_variance', ascending=False)


def _print_importance_ranking(table):
    """Print ``param_name: deviation_variance`` for every row of *table*, in order."""
    for _, row in table.iterrows():
        print(f"{row['param_name']}: {row['deviation_variance']:.3f}")


def _closest_group(frame, keys):
    """Return ``(key, summary)`` of the group with the smallest mean deviation.

    Returns ``(None, None)`` when *frame* has no groups or no group has a
    finite mean deviation. Ties keep the first group in groupby order.
    """
    best_key = None
    best_summary = None
    best_deviation = float('inf')
    for key, group in frame.groupby(keys):
        mean_deviation = group['rank_deviation'].mean()
        if mean_deviation < best_deviation:
            best_deviation = mean_deviation
            best_key = key
            best_summary = _group_summary(group)
    return best_key, best_summary


def _print_best_set(label, key, summary, lead=""):
    """Print the best parameter set of one threshold type, or a notice if there is none."""
    print(f"{lead}BEST {label} THRESHOLD PARAMETER SET:")
    if key is None:
        print("  No experiments available for this threshold type")
        return
    print(f"  r_square_threshold: {key[0]}")
    print(f"  rank_reduction_frequency: {key[1]}")
    print(f"  rank_reduction_threshold: {key[2]}")
    print(f"  patience: {key[3]}")
    print(f"  Performance:")
    print(f"    Mean rank: {summary['mean_rank']:.3f} ± {summary['sem_rank']:.3f}")
    print(f"    Std rank: {summary['std_rank']:.3f}")
    print(f"    Mean deviation: {summary['mean_deviation']:.3f} ± {summary['sem_deviation']:.3f}")
    print(f"    N experiments: {summary['n_experiments']}")


def _build_summary_table(overall_df, param_df, importance_df, best_params_df):
    """Stack the four analysis tables into one long (type, param, value, metric, value) table."""
    rows = []

    def add(analysis_type, param_name, param_value, metric, value):
        rows.append({
            'analysis_type': analysis_type,
            'param_name': param_name,
            'param_value': param_value,
            'metric': metric,
            'value': value,
        })

    for _, row in overall_df.iterrows():
        add('overall', 'all', 'all', row['metric'], row['value'])

    for _, row in param_df.iterrows():
        for metric in ['mean_rank', 'sem_rank', 'mean_deviation', 'sem_deviation']:
            add('parameter_analysis', row['param_name'], row['param_value'], metric, row[metric])

    for _, row in importance_df.iterrows():
        for metric in ['deviation_variance', 'deviation_range']:
            add('importance', row['param_name'], 'all', metric, row[metric])

    for _, row in best_params_df.iterrows():
        for metric in ['best_value', 'best_deviation', 'mean_rank', 'sem_rank']:
            # The 'best_value' row carries the value itself in both columns.
            param_value = row['best_value'] if metric == 'best_value' else 'best'
            add('best_params', row['param_name'], param_value, metric, row[metric])

    return pd.DataFrame(
        rows, columns=['analysis_type', 'param_name', 'param_value', 'metric', 'value'])


def analyze_unimodal_sweep_results(results_dir="03_results/paper_results/unimodal", target_rank=5):
    """
    Analyze unimodal parameter sweep results from multiple CSV files.

    Reads every ``unimodal_param_sweep_*.csv`` in *results_dir* (in sorted
    filename order), concatenates them, and measures how far the
    ``final_ranks`` column lies from *target_rank*. Progress and all tables
    are printed, and nine CSV reports are written back into *results_dir*.

    Args:
        results_dir: Directory holding the sweep CSVs; also the output
            directory. Relative paths are resolved against the current
            working directory.
        target_rank: Rank the sweep is expected to recover; deviations are
            computed against it.

    Returns:
        ``None`` when no result file is found. Otherwise a dict of DataFrames
        with keys ``overall_stats``, ``parameter_analysis``,
        ``importance_analysis``, ``absolute_stats``, ``absolute_importance``,
        ``absolute_parameters``, ``best_parameters``, ``threshold_comparison``
        and ``combined_results``. ``threshold_comparison`` has one row per
        threshold type ('relative', 'absolute') that is present in the data.

    Raises:
        KeyError: If the combined data lacks ``final_ranks`` or one of the
            swept parameter columns.
    """
    # Find all result files. Sorting makes the concatenation order (and thus
    # the saved combined table) independent of the OS directory order.
    pattern = f"{glob.escape(results_dir)}/unimodal_param_sweep_*.csv"
    result_files = sorted(glob.glob(pattern))
    print(f"Found {len(result_files)} result files:")
    for path in result_files:
        print(f"  - {path}")

    if not result_files:
        print("No result files found!")
        return

    # Load and combine all results
    combined_results = pd.concat(
        [pd.read_csv(path) for path in result_files], ignore_index=True)
    print(f"\nLoaded {len(combined_results)} total experiments across {len(result_files)} files")

    # Define parameter columns
    param_columns = ['r_square_threshold', 'early_stopping', 'rank_reduction_frequency',
                     'rank_reduction_threshold', 'patience', 'threshold']

    # Fail early with one clear message instead of deep inside a groupby.
    missing = [col for col in ['final_ranks'] + param_columns
               if col not in combined_results.columns]
    if missing:
        raise KeyError(f"Result files are missing required columns: {missing}")

    # Use final ranks for analysis, then derive signed and absolute error.
    combined_results['final_rank'] = combined_results['final_ranks']
    combined_results['rank_deviation'] = np.abs(combined_results['final_rank'] - target_rank)
    combined_results['rank_error'] = combined_results['final_rank'] - target_rank

    print(f"\nOverall Statistics:")
    _print_rank_overview(combined_results, target_rank=target_rank)

    # 1. Report mean ± SEM rank for all sweeps
    overall_df = _rank_summary_table(combined_results, 'overall')
    _banner("OVERALL STATISTICS")
    print(overall_df.to_string(index=False))

    # 2. Report mean ± SEM rank per parameter tested
    param_analysis = []
    for param in param_columns:
        _banner(f"PARAMETER ANALYSIS: {param.upper()}")
        param_analysis.extend(_per_value_rows(combined_results, param))
    param_df = pd.DataFrame(param_analysis, columns=_PER_VALUE_COLUMNS)

    # 3. Parameter importance (spread of mean deviation across parameter values)
    _banner(f"PARAMETER IMPORTANCE (Average deviation from target rank {target_rank})")
    importance_df = _importance_table(combined_results, param_columns)
    print(f"\nParameter importance ranking (by deviation variance):")
    _print_importance_ranking(importance_df)

    # 3b. ABSOLUTE THRESHOLD SPECIFIC ANALYSIS
    _banner("ABSOLUTE THRESHOLD SPECIFIC ANALYSIS")
    absolute_results = combined_results[combined_results['threshold'] == 'absolute']
    print(f"Absolute threshold experiments: {len(absolute_results)}")
    _print_rank_overview(absolute_results)

    absolute_stats_df = _rank_summary_table(absolute_results, 'absolute')
    print(f"\nAbsolute Threshold Summary Statistics:")
    print(absolute_stats_df.to_string(index=False))

    print(f"\nABSOLUTE THRESHOLD PARAMETER IMPORTANCE:")
    absolute_importance_df = _importance_table(absolute_results, _COMBO_KEYS)
    print(f"\nAbsolute threshold parameter importance ranking:")
    _print_importance_ranking(absolute_importance_df)

    print(f"\nDETAILED ABSOLUTE THRESHOLD PARAMETER ANALYSIS:")
    absolute_param_analysis = []
    for param in _COMBO_KEYS:
        print(f"\n{param.upper()}:")
        absolute_param_analysis.extend(_per_value_rows(absolute_results, param, indent="  "))
    absolute_param_df = pd.DataFrame(absolute_param_analysis, columns=_PER_VALUE_COLUMNS)

    # 4. Best value of each parameter (smallest mean deviation from the target)
    _banner(f"BEST PARAMETER VALUES (Closest to target rank {target_rank})")
    best_params = []
    for param in param_columns:
        best_value, summary = _closest_group(combined_results, param)
        if summary is None:
            # No group with a finite deviation: record placeholders, do not crash.
            summary = {'mean_rank': np.nan, 'sem_rank': np.nan,
                       'mean_deviation': np.nan, 'n_experiments': 0}
        best_params.append({
            'param_name': param,
            'best_value': best_value,
            'best_deviation': summary['mean_deviation'],
            'mean_rank': summary['mean_rank'],
            'sem_rank': summary['sem_rank'],
            'n_experiments': summary['n_experiments'],
        })
        print(f"{param}: {best_value} (deviation: {summary['mean_deviation']:.3f}, "
              f"rank: {summary['mean_rank']:.2f} ± {summary['sem_rank']:.2f})")
    best_params_df = pd.DataFrame(
        best_params,
        columns=['param_name', 'best_value', 'best_deviation',
                 'mean_rank', 'sem_rank', 'n_experiments'])

    # 5. Compare best parameter sets for relative vs absolute thresholds
    _banner("BEST PARAMETER SETS COMPARISON: RELATIVE vs ABSOLUTE")
    relative_results = combined_results[combined_results['threshold'] == 'relative']
    relative_key, relative_summary = _closest_group(relative_results, _COMBO_KEYS)
    absolute_key, absolute_summary = _closest_group(absolute_results, _COMBO_KEYS)

    _print_best_set("RELATIVE", relative_key, relative_summary)
    _print_best_set("ABSOLUTE", absolute_key, absolute_summary, lead="\n")

    _banner("FINAL VERDICT")
    if relative_key is None and absolute_key is None:
        print("No relative or absolute threshold experiments available; no verdict.")
    else:
        # Relative wins only on a strictly smaller deviation; ties go to absolute.
        relative_wins = absolute_key is None or (
            relative_key is not None
            and relative_summary['mean_deviation'] < absolute_summary['mean_deviation']
        )
        if relative_wins:
            winner, winner_params = "RELATIVE", relative_key
            winner_summary, loser_summary = relative_summary, absolute_summary
        else:
            winner, winner_params = "ABSOLUTE", absolute_key
            winner_summary, loser_summary = absolute_summary, relative_summary

        winner_deviation = winner_summary['mean_deviation']
        if loser_summary is None:
            improvement_text = "n/a (no alternative threshold type)"
        else:
            loser_deviation = loser_summary['mean_deviation']
            # A zero loser deviation means both are exact: report 0% rather than nan.
            if loser_deviation == 0:
                improvement = 0.0
            else:
                improvement = ((loser_deviation - winner_deviation) / loser_deviation) * 100
            improvement_text = f"{improvement:.1f}%"

        print(f"🏆 WINNER: {winner} THRESHOLD")
        print(f"Best deviation from target: {winner_deviation:.3f}")
        print(f"Best mean rank: {winner_summary['mean_rank']:.3f} (target: {float(target_rank):.1f})")
        print(f"Improvement over alternative: {improvement_text}")
        print(f"Optimal parameters: r_square={winner_params[0]}, freq={winner_params[1]}, "
              f"thresh={winner_params[2]}, patience={winner_params[3]}")

    # Create comparison dataframe (one row per threshold type that has data)
    comparison_rows = []
    for label, key, summary in (('relative_best', relative_key, relative_summary),
                                ('absolute_best', absolute_key, absolute_summary)):
        if key is None:
            continue
        comparison_rows.append({
            'threshold_type': label,
            'r_square_threshold': key[0],
            'rank_reduction_frequency': key[1],
            'rank_reduction_threshold': key[2],
            'patience': key[3],
            'mean_rank': summary['mean_rank'],
            'sem_rank': summary['sem_rank'],
            'std_rank': summary['std_rank'],
            'mean_deviation': summary['mean_deviation'],
            'sem_deviation': summary['sem_deviation'],
            'n_experiments': summary['n_experiments'],
        })
    comparison_df = pd.DataFrame(
        comparison_rows,
        columns=['threshold_type'] + _COMBO_KEYS
        + ['mean_rank', 'sem_rank', 'std_rank', 'mean_deviation',
           'sem_deviation', 'n_experiments'])

    # 6. Save comprehensive results
    _banner("SAVING RESULTS")
    summary_df = _build_summary_table(overall_df, param_df, importance_df, best_params_df)

    outputs = [
        ("Summary", "unimodal_sweep_summary.csv", summary_df),
        ("Parameter analysis", "unimodal_parameter_analysis.csv", param_df),
        ("Importance analysis", "unimodal_importance_analysis.csv", importance_df),
        ("Absolute threshold summary", "unimodal_absolute_summary.csv", absolute_stats_df),
        ("Absolute threshold importance", "unimodal_absolute_importance.csv", absolute_importance_df),
        ("Absolute threshold parameters", "unimodal_absolute_parameters.csv", absolute_param_df),
        ("Best parameters", "unimodal_best_parameters.csv", best_params_df),
        ("Threshold comparison", "unimodal_threshold_comparison.csv", comparison_df),
        # The combined input rows plus the final_rank / deviation / error columns.
        ("Enhanced results", "unimodal_enhanced_results.csv", combined_results),
    ]
    for label, filename, table in outputs:
        destination = f"{results_dir}/{filename}"
        table.to_csv(destination, index=False)
        print(f"{label} saved to: {destination}")

    print(f"\nAnalysis complete! Check the saved files for detailed results.")

    return {
        'overall_stats': overall_df,
        'parameter_analysis': param_df,
        'importance_analysis': importance_df,
        'absolute_stats': absolute_stats_df,
        'absolute_importance': absolute_importance_df,
        'absolute_parameters': absolute_param_df,
        'best_parameters': best_params_df,
        'threshold_comparison': comparison_df,
        'combined_results': combined_results
    }

if __name__ == "__main__":
    # Change to the directory three levels above this script so that the
    # relative results path used by the analysis resolves regardless of
    # where the script is launched from.
    from pathlib import Path

    # Resolve BEFORE walking up: the parent of a relative path such as
    # "script.py" bottoms out at ".", and unresolved ".." segments would make
    # the parent walk land on the wrong directory.
    project_root = Path(__file__).resolve().parent.parent.parent
    os.chdir(str(project_root))

    # Run the analysis
    results = analyze_unimodal_sweep_results()