# experiments.py
# Comprehensive experimental framework for U-RankMOEA
# Includes benchmark comparisons, statistical analysis, and result visualization

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import wilcoxon, mannwhitneyu, kruskal
from sklearn.preprocessing import StandardScaler
import time
import os
import pickle
import json
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import warnings
warnings.filterwarnings("ignore")

from u_rankmoea import URankMOEA, CFG
from benchmark_problems import *
from metrics import *
from visualization import *
from baseline_algorithms import *

class ExperimentRunner:
    """Comprehensive experiment runner for U-RankMOEA evaluation"""
    
    def __init__(self, output_dir="experiments_output"):
        """Prepare the output directory and the default experiment setup.

        Args:
            output_dir: Directory that receives every artifact written by
                this runner. It is created (including parents) if missing.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Filled in by run_comparison_study() / perform_statistical_analysis().
        self.results = {}
        self.statistical_results = {}

        # Default benchmark grid. 'D' lists the decision-space sizes and 'M'
        # holds the objective count (a single int) or counts (a list).
        zdt_problems = ('ZDT1', 'ZDT2', 'ZDT3', 'ZDT4', 'ZDT6')
        dtlz_problems = ('DTLZ1', 'DTLZ2', 'DTLZ3', 'DTLZ4', 'DTLZ7')

        self.problem_configs = {}
        for name in zdt_problems:
            # A fresh list per entry so the configs never share state.
            self.problem_configs[name] = {'D': [30, 50, 100, 200], 'M': 2}
        for name in dtlz_problems:
            self.problem_configs[name] = {'D': [30, 50, 100, 200], 'M': [2, 3]}

        # Display name -> class used to instantiate the optimizer.
        self.algorithms = {
            'U-RankMOEA': URankMOEA,
            'NSGA-II': NSGAII_Baseline,
            'MOEA/D': MOEAD_Baseline,
            'Random': RandomBaseline,
            'GP-NSGA-II': GP_NSGA_II_Baseline,
        }

        # Number of independent repetitions per experiment configuration.
        self.n_runs = 20
        # Threshold that p-values are compared against in
        # perform_statistical_analysis() (i.e. a significance level).
        self.confidence_level = 0.05
        
    def create_algorithm_config(self, problem_name, D, M, maxFEs=300):
        """Build the configuration for one problem instance.

        Starts from a copy of the module-level ``CFG`` and overrides the
        entries that depend on the problem size and evaluation budget.

        Args:
            problem_name: Benchmark name; names containing ``'DTLZ'`` get a
                larger initial sample.
            D: Number of decision variables.
            M: Number of objectives.
            maxFEs: Function-evaluation budget.

        Returns:
            The customised copy of ``CFG``.
        """
        # Initial sample size scales with D, within problem-family bounds.
        if 'DTLZ' in problem_name:
            n_init = max(100, min(200, 3*D))
        else:
            n_init = max(50, min(100, 2*D))

        cfg = CFG.copy()
        cfg['D'] = D
        cfg['M'] = M
        cfg['maxFEs'] = maxFEs
        cfg['N_init'] = n_init
        # Batch size follows the budget but stays within [4, 8].
        cfg['batch_size'] = max(4, min(8, maxFEs//50))

        # Fixed settings applied to large decision spaces.
        if D >= 100:
            cfg.update({
                'screening_pool': 2000,
                'expensive_eval': 200,
                'dgp_epochs': 60,
                'clf_epochs': 40,
            })

        return cfg
    
    def run_single_experiment(self, algorithm_name, problem_name, D, M, run_id, maxFEs=300):
        """Execute one trial of an algorithm on a problem instance.

        Args:
            algorithm_name: ``'U-RankMOEA'`` or a key of ``self.algorithms``.
            problem_name: Benchmark problem name.
            D: Number of decision variables.
            M: Number of objectives.
            run_id: Repetition index; also determines the random seed.
            maxFEs: Function-evaluation budget.

        Returns:
            A dict describing the trial. On success it carries the final
            population, the non-dominated front, the HV/IGD histories, the
            wall time, the metrics and ``'success': True``. If anything
            raises, a short dict with ``'success': False`` and the error
            message is returned instead of propagating the exception.
        """
        # Fields that identify the trial; shared by both outcome records.
        identity = {
            'algorithm': algorithm_name,
            'problem': problem_name,
            'D': D,
            'M': M,
            'run_id': run_id,
        }

        try:
            print(f"Running {algorithm_name} on {problem_name}({D}D, {M}M) - Run {run_id}")

            # Same run_id -> same seed, for every algorithm.
            seed = 42 + run_id * 1000
            np.random.seed(seed)

            problem_func = get_problem_function(problem_name, M)
            true_front = get_true_pareto_front(problem_name, M)

            cfg = self.create_algorithm_config(problem_name, D, M, maxFEs)
            cfg['seed'] = seed

            started = time.time()

            if algorithm_name != 'U-RankMOEA':
                # Baselines return their histories directly from run().
                solver = self.algorithms[algorithm_name](cfg, problem_func)
                final_X, final_Y, hv_history, igd_history = solver.run()
            else:
                # U-RankMOEA exposes its histories as attributes.
                solver = URankMOEA(cfg, problem_func)
                final_X, final_Y = solver.run()
                hv_history = solver.hv_history
                igd_history = solver.igd_history

            wall_time = time.time() - started

            final_front = nondominated_frontpoints(final_Y)
            metrics = calculate_all_metrics(final_Y, true_front, problem_name)

            outcome = {
                **identity,
                'final_X': final_X,
                'final_Y': final_Y,
                'final_front': final_front,
                'hv_history': hv_history,
                'igd_history': igd_history,
                'wall_time': wall_time,
                'FEs': len(final_Y),
                'metrics': metrics,
                'success': True,
            }

            print(f"  Completed: HV={metrics['HV']:.4f}, IGD={metrics['IGD']:.4f}, Time={wall_time:.1f}s")
            return outcome

        except Exception as e:
            # Best-effort boundary: one failed trial must not stop the study.
            print(f"  Failed: {str(e)}")
            return {**identity, 'success': False, 'error': str(e)}
    
    def run_comparison_study(self, problems_subset=None, algorithms_subset=None,
                           parallel=True, max_workers=4):
        """Run every requested algorithm on every requested problem instance.

        Args:
            problems_subset: Problem names to include; defaults to all keys
                of ``self.problem_configs``. Names without a config entry are
                skipped.
            algorithms_subset: Algorithm names to include; defaults to all
                keys of ``self.algorithms``.
            parallel: Use a process pool when more than 10 experiments are
                queued; otherwise run sequentially.
            max_workers: Pool size for the parallel path.

        The outcomes are grouped into ``self.results`` and written to disk
        via ``save_results()``.
        """
        print("=== Starting Comparison Study ===")

        if problems_subset is None:
            problems_subset = list(self.problem_configs.keys())
        if algorithms_subset is None:
            algorithms_subset = list(self.algorithms.keys())

        # Resolve each known problem to (name, dimensions, objective counts).
        grid = []
        for problem_name in problems_subset:
            if problem_name in self.problem_configs:
                spec = self.problem_configs[problem_name]
                dimensions = spec['D']
                objectives = spec['M']
                # A bare int means a single objective count.
                if isinstance(objectives, int):
                    objectives = [objectives]
                grid.append((problem_name, dimensions, objectives))

        # One job per (problem, D, M, algorithm, repetition), in that nesting.
        experiments = [
            (algorithm_name, problem_name, D, M, run_id)
            for problem_name, dimensions, objectives in grid
            for D in dimensions
            for M in objectives
            for algorithm_name in algorithms_subset
            for run_id in range(self.n_runs)
        ]

        print(f"Total experiments to run: {len(experiments)}")

        # Small batches are not worth the process-pool overhead.
        if not (parallel and len(experiments) > 10):
            outcomes = self._run_sequential_experiments(experiments)
        else:
            outcomes = self._run_parallel_experiments(experiments, max_workers)

        self._organize_results(outcomes)
        self.save_results()

        print("=== Comparison Study Completed ===")
    
    def _run_parallel_experiments(self, experiments, max_workers):
        """Run experiments in a process pool and collect their outcomes.

        Args:
            experiments: Iterable of argument tuples for
                ``run_single_experiment``; the first five items are
                ``(algorithm_name, problem_name, D, M, run_id)``.
            max_workers: Number of worker processes.

        Returns:
            A list with one result dict per experiment, in completion order.
            ``run_single_experiment`` already reports failures inside the
            experiment itself. Failures of the pool machinery (an
            unpicklable payload, a crashed worker, ...) surface from
            ``future.result()`` and are converted here into the same
            ``'success': False`` record, so one bad job cannot discard the
            results that were already collected.
        """
        results = []

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Keep the originating experiment for each future so a failure
            # can still be attributed to its configuration.
            future_to_exp = {
                executor.submit(self.run_single_experiment, *exp): exp
                for exp in experiments
            }
            total = len(future_to_exp)

            for future in as_completed(future_to_exp):
                try:
                    result = future.result()
                except Exception as e:
                    algorithm_name, problem_name, D, M, run_id = future_to_exp[future][:5]
                    print(f"  Failed: {str(e)}")
                    result = {
                        'algorithm': algorithm_name,
                        'problem': problem_name,
                        'D': D,
                        'M': M,
                        'run_id': run_id,
                        'success': False,
                        'error': str(e),
                    }
                results.append(result)

                # Progress update
                if len(results) % 20 == 0:
                    print(f"Progress: {len(results)}/{total} experiments completed")

        return results
    
    def _run_sequential_experiments(self, experiments):
        """Run the experiments one after another in the current process.

        Args:
            experiments: Sequence of argument tuples for
                ``run_single_experiment``.

        Returns:
            The result dicts, in the same order as ``experiments``.
        """
        outcomes = []

        for done, exp in enumerate(experiments, start=1):
            outcomes.append(self.run_single_experiment(*exp))

            # Report after every tenth experiment.
            if done % 10 == 0:
                print(f"Progress: {done}/{len(experiments)} experiments completed")

        return outcomes
    
    def _organize_results(self, results):
        """Group successful results as ``self.results[algorithm][instance]``.

        The instance key has the form ``"<problem>_<D>D_<M>M"``. Records with
        a falsy ``'success'`` entry are dropped. Any previous content of
        ``self.results`` is discarded.
        """
        self.results = {}
        grouped = self.results

        for record in results:
            if record['success']:
                algorithm = record['algorithm']
                instance_key = f"{record['problem']}_{record['D']}D_{record['M']}M"
                # Create the nested containers on first use.
                grouped.setdefault(algorithm, {}).setdefault(instance_key, []).append(record)
    
    def perform_statistical_analysis(self):
        """Compare algorithms pairwise on each problem instance.

        For every problem instance on which at least two algorithms have five
        or more successful runs, each pair of algorithms is compared on HV,
        IGD and Spacing with a two-sided Mann-Whitney U test (also known as
        the Wilcoxon rank-sum test).

        The outcome is stored in ``self.statistical_results`` as
        ``[problem_key][metric]["<alg1>_vs_<alg2>"]``. Each entry holds the
        test statistic, the p-value, whether the p-value is below
        ``self.confidence_level``, the absolute median difference, both
        medians, and the algorithm with the better median (larger for HV,
        smaller for IGD and Spacing; ``alg2`` when the medians tie).

        Data are grouped by problem key directly rather than through a
        combined ``"<algorithm>_<problem>"`` string, so algorithm names that
        contain underscores cannot be mistaken for part of a problem key.
        Problems with fewer than two eligible algorithms get no entry.
        """
        print("=== Performing Statistical Analysis ===")

        min_runs = 5  # Need sufficient samples
        metric_names = ('HV', 'IGD', 'Spacing')

        # problem_key -> algorithm -> metric -> list of per-run values
        per_problem = {}
        for algorithm, by_problem in self.results.items():
            for problem_key, results_list in by_problem.items():
                if len(results_list) < min_runs:
                    continue
                per_problem.setdefault(problem_key, {})[algorithm] = {
                    metric: [r['metrics'][metric] for r in results_list]
                    for metric in metric_names
                }

        self.statistical_results = {}

        for problem_key, problem_data in per_problem.items():
            if len(problem_data) < 2:
                continue

            algorithm_names = list(problem_data)
            problem_stats = {}

            for metric in metric_names:
                higher_is_better = metric == 'HV'
                metric_stats = {}

                for i, alg1 in enumerate(algorithm_names):
                    for alg2 in algorithm_names[i + 1:]:
                        data1 = problem_data[alg1][metric]
                        data2 = problem_data[alg2][metric]

                        # Deliberately best-effort: a failing test for one
                        # pair must not abort the remaining comparisons.
                        try:
                            statistic, p_value = mannwhitneyu(data1, data2, alternative='two-sided')
                        except Exception as e:
                            print(f"Statistical test failed for {alg1} vs {alg2} on {metric}: {e}")
                            continue

                        median1 = np.median(data1)
                        median2 = np.median(data2)
                        if higher_is_better:
                            first_is_better = median1 > median2
                        else:
                            first_is_better = median1 < median2

                        metric_stats[f"{alg1}_vs_{alg2}"] = {
                            'statistic': statistic,
                            'p_value': p_value,
                            'significant': p_value < self.confidence_level,
                            'effect_size': abs(median1 - median2),
                            'median1': median1,
                            'median2': median2,
                            'better': alg1 if first_is_better else alg2,
                        }

                problem_stats[metric] = metric_stats

            self.statistical_results[problem_key] = problem_stats

        print("Statistical analysis completed")
    
    def generate_summary_tables(self):
        """Aggregate per-run metrics into summary tables.

        For every algorithm/problem pair with at least one result, the mean
        and standard deviation (``np.std`` defaults, i.e. population std) of
        HV, IGD, Spacing and wall time are computed. The table is written to
        ``summary_table.csv`` and, through ``_create_latex_table``, to
        ``results_table.tex`` inside ``self.output_dir``.

        Returns:
            The summary ``pandas.DataFrame`` rounded to 6 decimals. The
            column set is fixed, so the frame has all its columns even when
            there are no results (previously an empty, column-less frame made
            the LaTeX export fail with ``KeyError``).
        """
        print("=== Generating Summary Tables ===")

        columns = [
            'Algorithm', 'Problem',
            'HV_mean', 'HV_std',
            'IGD_mean', 'IGD_std',
            'Spacing_mean', 'Spacing_std',
            'Time_mean', 'Time_std',
            'Runs',
        ]

        summary_data = []
        for algorithm, by_problem in self.results.items():
            for problem_key, results_list in by_problem.items():
                if not results_list:
                    continue

                hv_values = [r['metrics']['HV'] for r in results_list]
                igd_values = [r['metrics']['IGD'] for r in results_list]
                spacing_values = [r['metrics']['Spacing'] for r in results_list]
                time_values = [r['wall_time'] for r in results_list]

                summary_data.append({
                    'Algorithm': algorithm,
                    'Problem': problem_key,
                    'HV_mean': np.mean(hv_values),
                    'HV_std': np.std(hv_values),
                    'IGD_mean': np.mean(igd_values),
                    'IGD_std': np.std(igd_values),
                    'Spacing_mean': np.mean(spacing_values),
                    'Spacing_std': np.std(spacing_values),
                    'Time_mean': np.mean(time_values),
                    'Time_std': np.std(time_values),
                    'Runs': len(results_list),
                })

        # Passing the columns explicitly keeps the schema stable for an
        # empty result set.
        summary_df = pd.DataFrame(summary_data, columns=columns)
        summary_df = summary_df.round(6)

        summary_df.to_csv(self.output_dir / 'summary_table.csv', index=False)

        # Create publication-ready LaTeX table
        self._create_latex_table(summary_df)

        print(f"Summary tables saved to {self.output_dir}")
        return summary_df
    
    def _create_latex_table(self, df):
        """Write the HV/IGD summary as a LaTeX table to ``results_table.tex``.

        Args:
            df: Summary frame with the columns ``Algorithm``, ``Problem``,
                ``HV_mean``, ``HV_std``, ``IGD_mean`` and ``IGD_std``.

        The table has one column per algorithm and two rows per problem (HV
        with 3 decimals, IGD with 4), each cell formatted as
        ``mean +/- std``; missing combinations are shown as ``--``.
        Problem and algorithm names are escaped for LaTeX. Problem keys such
        as ``ZDT1_30D_2M`` contain underscores, which would otherwise break
        compilation outside math mode.
        """
        # Single-pass translation so already-inserted backslashes are never
        # escaped a second time.
        latex_specials = str.maketrans({
            '\\': r'\textbackslash{}',
            '&': r'\&',
            '%': r'\%',
            '$': r'\$',
            '#': r'\#',
            '_': r'\_',
            '{': r'\{',
            '}': r'\}',
            '~': r'\textasciitilde{}',
            '^': r'\textasciicircum{}',
        })

        def escape(text):
            """Escape LaTeX special characters in a text-mode cell."""
            return str(text).translate(latex_specials)

        problems = df['Problem'].unique()
        algorithms = df['Algorithm'].unique()

        def metric_row(problem, problem_data, metric, precision):
            """Build one table row for ``metric`` of ``problem``."""
            cells = [f"{escape(problem)} ({metric})"]
            for alg in algorithms:
                alg_data = problem_data[problem_data['Algorithm'] == alg]
                if len(alg_data) > 0:
                    mean_val = alg_data[f'{metric}_mean'].iloc[0]
                    std_val = alg_data[f'{metric}_std'].iloc[0]
                    cells.append(f"${mean_val:.{precision}f}\\pm{std_val:.{precision}f}$")
                else:
                    cells.append("--")
            return " & ".join(cells) + " \\\\"

        latex_content = [
            "\\begin{table}[htbp]",
            "\\centering",
            "\\caption{Performance Comparison Results}",
            "\\label{tab:results}",
            "\\begin{tabular}{l|" + "c" * len(algorithms) + "}",
            "\\hline",
            " & ".join(["Problem"] + [escape(alg) for alg in algorithms]) + " \\\\",
            "\\hline",
        ]

        for problem in sorted(problems):
            problem_data = df[df['Problem'] == problem]
            latex_content.append(metric_row(problem, problem_data, 'HV', 3))
            latex_content.append(metric_row(problem, problem_data, 'IGD', 4))
            latex_content.append("\\hline")

        latex_content.append("\\end{tabular}")
        latex_content.append("\\end{table}")

        with open(self.output_dir / 'results_table.tex', 'w') as f:
            f.write('\n'.join(latex_content))
    
    def create_visualizations(self):
        """Produce the full set of figures for the collected results.

        A ``VisualizationSuite`` is built from ``self.results`` and
        ``self.output_dir`` and its plotting methods are called in a fixed
        order; the significance heatmap additionally receives
        ``self.statistical_results``.
        """
        print("=== Creating Visualizations ===")

        suite = VisualizationSuite(self.results, self.output_dir)

        # Cross-algorithm comparison figures.
        suite.create_performance_comparison()
        suite.create_convergence_plots()
        suite.create_pareto_front_comparisons()
        suite.create_statistical_significance_heatmap(self.statistical_results)

        # Figures that analyse algorithm behaviour.
        suite.create_scalability_analysis()
        suite.create_uncertainty_analysis()
        suite.create_component_contribution_analysis()

        print(f"Visualizations saved to {self.output_dir}")
    
    def save_results(self):
        """Persist results, statistics and the experiment setup.

        Writes three files into ``self.output_dir``:
        ``detailed_results.pkl`` (``self.results``),
        ``statistical_results.pkl`` (``self.statistical_results``) and
        ``experiment_config.json`` (problem grid, algorithm names, run count
        and significance threshold).
        """
        # (file name, attribute holding the payload), pickled in this order.
        pickle_targets = (
            ('detailed_results.pkl', 'results'),
            ('statistical_results.pkl', 'statistical_results'),
        )
        for filename, attribute in pickle_targets:
            with open(self.output_dir / filename, 'wb') as handle:
                pickle.dump(getattr(self, attribute), handle)

        # Only the algorithm names are stored; the classes themselves are
        # not JSON-serialisable.
        config_summary = {
            'problem_configs': self.problem_configs,
            'algorithms': list(self.algorithms.keys()),
            'n_runs': self.n_runs,
            'confidence_level': self.confidence_level,
        }
        config_path = self.output_dir / 'experiment_config.json'
        with open(config_path, 'w') as handle:
            json.dump(config_summary, handle, indent=2)
    
    def load_results(self):
        """Load results previously written by ``save_results``.

        Both ``detailed_results.pkl`` and ``statistical_results.pkl`` are
        read from ``self.output_dir``. The instance is only updated once both
        files have loaded, so a missing second file no longer leaves
        ``self.results`` replaced while the method reports failure.

        Returns:
            True if both files were loaded; False if either file is missing
            (in which case the instance is left unchanged).
        """
        # NOTE: pickle can execute arbitrary code while loading; only point
        # output_dir at files produced by this framework.
        try:
            with open(self.output_dir / 'detailed_results.pkl', 'rb') as f:
                results = pickle.load(f)

            with open(self.output_dir / 'statistical_results.pkl', 'rb') as f:
                statistical_results = pickle.load(f)
        except FileNotFoundError:
            print("No saved results found")
            return False

        self.results = results
        self.statistical_results = statistical_results

        print(f"Results loaded from {self.output_dir}")
        return True
    
    def run_ablation_study(self, n_runs=10):
        """Run U-RankMOEA variants with individual components switched off.

        Each variant is U-RankMOEA with a few configuration entries
        overridden. The variants are evaluated on a small problem set, the
        raw outcomes are pickled to ``ablation_results.pkl`` and then passed
        to ``_analyze_ablation_results``.

        Previously the variant names were handed straight to
        ``run_single_experiment``, which does not know them, so every run
        failed and the per-variant configuration was never applied. The
        variants are now registered temporarily as algorithms whose
        constructor applies the overrides on top of the problem-specific
        configuration.

        Args:
            n_runs: Repetitions per variant and problem instance (fewer than
                in the main study by default).
        """
        print("=== Running Ablation Study ===")

        # Variant name -> configuration entries that differ from the full
        # algorithm. Only the overrides are kept, so the problem-specific
        # settings from create_algorithm_config (D, M, budget, ...) survive.
        ablation_overrides = {
            'U-RankMOEA-Full': {},
            'U-RankMOEA-NoClf': {'clf_ensembles': 0},            # No classifier
            'U-RankMOEA-NoDGP': {'dgp_ensembles': 0},            # No Deep GP
            'U-RankMOEA-NoAcq': {'acq_epochs': 0},               # No acquisition
            'U-RankMOEA-NoUncertainty': {'uncertainty_weight': 0},  # No uncertainty
        }

        # Test problems for ablation
        test_problems = ['ZDT1', 'ZDT2', 'DTLZ2']
        test_dimensions = [30, 100]
        n_objectives = 2

        def make_variant(overrides):
            """Return a baseline-style class running U-RankMOEA with overrides."""

            class _AblationVariant:
                def __init__(self, cfg, problem_func):
                    variant_cfg = cfg.copy()
                    variant_cfg.update(overrides)
                    self._algorithm = URankMOEA(variant_cfg, problem_func)

                def run(self):
                    # Same 4-tuple that run_single_experiment expects from
                    # the entries of self.algorithms.
                    final_X, final_Y = self._algorithm.run()
                    return (final_X, final_Y,
                            self._algorithm.hv_history,
                            self._algorithm.igd_history)

            return _AblationVariant

        ablation_results = {}

        # Swap in an extended registry for the duration of the study; the
        # original dict object is never modified and is restored afterwards.
        registered_algorithms = self.algorithms
        self.algorithms = {
            **registered_algorithms,
            **{name: make_variant(overrides)
               for name, overrides in ablation_overrides.items()},
        }
        try:
            for variant_name in ablation_overrides:
                ablation_results[variant_name] = {}

                for problem_name in test_problems:
                    for D in test_dimensions:
                        key = f"{problem_name}_{D}D_{n_objectives}M"
                        runs = []
                        for run_id in range(n_runs):
                            result = self.run_single_experiment(
                                variant_name, problem_name, D, n_objectives, run_id
                            )
                            if result['success']:
                                runs.append(result)
                        ablation_results[variant_name][key] = runs
        finally:
            self.algorithms = registered_algorithms

        # Save ablation results
        with open(self.output_dir / 'ablation_results.pkl', 'wb') as f:
            pickle.dump(ablation_results, f)

        # Create ablation analysis
        self._analyze_ablation_results(ablation_results)

        print("Ablation study completed")
    
    def _analyze_ablation_results(self, ablation_results):
        """Quantify how much each ablated variant loses against the full one.

        Args:
            ablation_results: ``{variant: {problem_key: [result, ...]}}``.

        For every problem key that has at least one ``'U-RankMOEA-Full'``
        run, the mean HV and IGD of each variant are compared with the full
        algorithm and expressed as a percentage degradation. The analysis is
        written to ``ablation_analysis.json`` and then plotted.
        """
        baseline = 'U-RankMOEA-Full'

        def mean_of(runs, metric):
            return np.mean([run['metrics'][metric] for run in runs])

        ablation_analysis = {}
        problem_keys = set().union(
            *[list(per_problem.keys()) for per_problem in ablation_results.values()]
        )

        for problem_key in problem_keys:
            # Without baseline runs there is nothing to compare against.
            if baseline not in ablation_results or problem_key not in ablation_results[baseline]:
                continue
            baseline_runs = ablation_results[baseline][problem_key]
            if len(baseline_runs) == 0:
                continue

            full_hv = mean_of(baseline_runs, 'HV')
            full_igd = mean_of(baseline_runs, 'IGD')

            entry = {'Full': {'HV': full_hv, 'IGD': full_igd}}
            ablation_analysis[problem_key] = entry

            for variant, per_problem in ablation_results.items():
                if variant == baseline or problem_key not in per_problem:
                    continue
                variant_runs = per_problem[problem_key]
                if len(variant_runs) == 0:
                    continue

                variant_hv = mean_of(variant_runs, 'HV')
                variant_igd = mean_of(variant_runs, 'IGD')

                # Relative change in percent; positive means the variant is
                # worse (lower HV / higher IGD) than the full algorithm.
                # NOTE(review): a baseline mean of exactly 0 makes these
                # ratios non-finite.
                entry[variant] = {
                    'HV': variant_hv,
                    'IGD': variant_igd,
                    'HV_degradation': (full_hv - variant_hv) / full_hv * 100,
                    'IGD_degradation': (variant_igd - full_igd) / full_igd * 100,
                }

        with open(self.output_dir / 'ablation_analysis.json', 'w') as f:
            json.dump(ablation_analysis, f, indent=2)

        self._visualize_ablation_results(ablation_analysis)
    
    def _visualize_ablation_results(self, ablation_analysis):
        """Plot the mean HV/IGD degradation of every ablated component.

        Args:
            ablation_analysis: ``{problem_key: {variant: {...}}}`` as built
                by ``_analyze_ablation_results``; the ``'HV_degradation'``
                and ``'IGD_degradation'`` entries of each variant are
                averaged over all problems.

        The figure is saved as ``ablation_study.png`` in ``self.output_dir``.
        Axis limits always include zero and both the smallest and largest
        bar, so negative degradations (a variant beating the full algorithm)
        or an all-zero data set no longer give an inverted or collapsed
        axis. Non-finite degradations are ignored when averaging.
        """
        components = ['NoClf', 'NoDGP', 'NoAcq', 'NoUncertainty']
        component_labels = ['No Classifier', 'No Deep GP', 'No Acquisition', 'No Uncertainty']

        def mean_degradation(variant_name, field):
            """Average ``field`` of a variant over all problems (0 if none)."""
            values = [
                per_variant[variant_name][field]
                for per_variant in ablation_analysis.values()
                if variant_name in per_variant
            ]
            # inf/nan arise when the baseline metric was zero.
            finite = [v for v in values if np.isfinite(v)]
            return float(np.mean(finite)) if finite else 0.0

        variant_names = [f'U-RankMOEA-{component}' for component in components]
        hv_degradations = [mean_degradation(name, 'HV_degradation') for name in variant_names]
        igd_degradations = [mean_degradation(name, 'IGD_degradation') for name in variant_names]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        panels = (
            (ax1, hv_degradations, 'steelblue',
             'HV Degradation (%)', 'Component Contribution - Hypervolume'),
            (ax2, igd_degradations, 'coral',
             'IGD Degradation (%)', 'Component Contribution - IGD'),
        )

        for ax, values, color, ylabel, title in panels:
            bars = ax.bar(component_labels, values, color=color, alpha=0.7)
            ax.set_ylabel(ylabel)
            ax.set_title(title)

            lowest = min(0.0, min(values))
            highest = max(0.0, max(values))
            # Fall back to a unit span so the limits never coincide.
            span = (highest - lowest) or 1.0
            padding = 0.1 * span
            ax.set_ylim(lowest - padding if lowest < 0 else 0, highest + padding)

            # Place each label just beyond the end of its bar, scaled to
            # the data range rather than a fixed offset.
            for bar, val in zip(bars, values):
                is_positive = val >= 0
                offset = 0.02 * span if is_positive else -0.02 * span
                ax.text(bar.get_x() + bar.get_width() / 2, val + offset,
                        f'{val:.1f}%', ha='center',
                        va='bottom' if is_positive else 'top')

        plt.tight_layout()
        plt.savefig(self.output_dir / 'ablation_study.png', dpi=300, bbox_inches='tight')
        plt.close()

def main():
    """Run a small smoke-test study followed by analysis and plotting.

    The runner's default grid is replaced by two 30-dimensional,
    two-objective problems with three repetitions each, executed
    sequentially.
    """
    print("=== U-RankMOEA Experimental Framework ===")

    runner = ExperimentRunner("experiments_output")

    # --- Step 1: reduced comparison study ---------------------------------
    print("\n1. Running quick test...")
    quick_problems = ['ZDT1', 'DTLZ2']
    quick_algorithms = ['U-RankMOEA', 'Random']

    runner.problem_configs = {name: {'D': [30], 'M': 2} for name in quick_problems}
    # NOTE: perform_statistical_analysis() skips groups with fewer than
    # 5 runs, so this quick test produces no pairwise statistics.
    runner.n_runs = 3

    runner.run_comparison_study(
        problems_subset=quick_problems,
        algorithms_subset=quick_algorithms,
        parallel=False,
    )

    # --- Step 2: statistics -----------------------------------------------
    print("\n2. Performing statistical analysis...")
    runner.perform_statistical_analysis()

    # --- Step 3: tables ---------------------------------------------------
    print("\n3. Generating summary tables...")
    summary_df = runner.generate_summary_tables()
    print(summary_df)

    # --- Step 4: figures --------------------------------------------------
    print("\n4. Creating visualizations...")
    runner.create_visualizations()

    # Full study (disabled by default; enable by uncommenting):
    #
    #     print("\n5. Running full comparison study...")
    #     runner.n_runs = 20
    #     runner.run_comparison_study(parallel=True, max_workers=4)
    #
    #     print("\n6. Running ablation study...")
    #     runner.run_ablation_study()

    print(f"\nExperiments completed. Results saved to: {runner.output_dir}")


# Run the workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()