import os
import yaml
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict
import torch
import warnings
warnings.filterwarnings('ignore')

import sys


from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP
from MOCO.evaluation import MOCOEvaluator

# from our_method_MW_UCB_fixed_OCO_accelerated_universal_agnostic_v3_ablation import (
#     CachedAdvancedBiKPWrapper as UCBWrapper
# )
from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import CachedAdvancedBiKPWrapper as UCBWrapper

# from our_method_MW_Thompson_fixed_OCO_accelerated_universal_agnostic import (
#     CachedAdvancedBiKPWrapper as ThompsonWrapper
# )
from our_method_dl_Thompson_variant import (
    CachedAdvancedBiKPWrapper as ThompsonWrapper
)

def setup_icml_style():
    """Apply the plotting defaults used for ICML-style paper figures.

    Selects the ``seaborn-v0_8-paper`` matplotlib style and the seaborn
    ``husl`` palette, then overrides the global ``plt.rcParams`` with the
    font, figure, layout, grid/spine, line and legend settings below.
    """
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("husl")

    # Sans-serif fonts, with a fallback chain.
    font_settings = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'Liberation Sans'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
    }

    # On-screen vs. saved resolution and tight cropping on save.
    figure_settings = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
    }

    # Constrained layout on by default; autolayout explicitly off.
    layout_settings = {
        'figure.constrained_layout.use': True,
        'figure.autolayout': False,
    }

    # Light dashed grid, no top/right spines.
    grid_and_spine_settings = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
    }

    line_settings = {
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
    }

    legend_settings = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
        'legend.edgecolor': '0.8',
    }

    # Draw grid lines below the plotted data.
    other_settings = {
        'axes.axisbelow': True,
    }

    overrides = {}
    for group in (
        font_settings,
        figure_settings,
        layout_settings,
        grid_and_spine_settings,
        line_settings,
        legend_settings,
        other_settings,
    ):
        overrides.update(group)

    plt.rcParams.update(overrides)

# Apply the ICML style once at import time; setup_icml_style() updates the
# global matplotlib rcParams, so this is a module-level side effect.
setup_icml_style()


def _bitsp_default_ref(size):
    """Return the default reference point for a BiTSP instance of *size* cities."""
    if size == 20:
        return (20, 20)
    if size == 50:
        return (35, 35)
    return (65, 65)


def _bikp_default_ref(size):
    """Return the default reference point for a BiKP instance of *size* items."""
    if size == 50:
        return (5, 5)
    if size == 100:
        return (20, 20)
    return (30, 30)


# Registry of supported problems, keyed by the short problem-type name.
# Each entry provides:
#   'class'       - the problem class handed to the evaluator,
#   'ref_type'    - a label for the reference-point family,
#   'sizes'       - constructor kwargs for the 'small'/'medium'/'large' presets,
#   'default_ref' - callable mapping the actual problem size (n_cities or
#                   n_items) to the reference point given to MOCOEvaluator.
# NOTE(review): decomposition_diagnostic_with_convergence defaults to
# problem_type='BiObjectiveTSP', which is not a key of this registry - confirm
# whether class-name aliases are expected here.
PROBLEM_REGISTRY = {
    'BiTSP': {
        'class': BiObjectiveTSP,
        'ref_type': 'BiTSP',
        'sizes': {
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
        },
        'default_ref': _bitsp_default_ref,
    },
    'BiKP': {
        'class': MultiObjectiveKnapsack,
        'ref_type': 'BiKP',
        'sizes': {
            'small': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5},
            'medium': {'n_items': 100, 'n_objectives': 2, 'capacity': 25.0},
            'large': {'n_items': 200, 'n_objectives': 2, 'capacity': 50.0},
        },
        'default_ref': _bikp_default_ref,
    },
}


def _ablation_to_builtin(value):
    """Recursively convert *value* to plain Python containers and scalars.

    Dicts and lists are rebuilt, tuples become lists, and anything exposing
    ``tolist`` (NumPy scalars/arrays, tensors) is unwrapped. The result dumps
    to plain YAML without ``!!python/...`` tags, so ``yaml.safe_load`` can
    read it back.
    """
    if isinstance(value, dict):
        return {key: _ablation_to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_ablation_to_builtin(item) for item in value]
    if hasattr(value, 'tolist'):
        unwrapped = value.tolist()
        # Guard against objects whose tolist() returns themselves.
        return value if unwrapped is value else _ablation_to_builtin(unwrapped)
    return value


def _default_decomposition_sizes(actual_size):
    """Return the sorted, de-duplicated default decomposition sizes.

    Candidates are roughly 10/20/33/50/60/67% of the problem with small
    absolute floors; each is clamped to *actual_size* so a sub-problem is
    never larger than the problem itself (e.g. 25 or 30 for 20 cities).
    """
    candidates = (
        max(5, actual_size // 10),       # 10%
        max(10, actual_size // 5),       # 20%
        max(15, actual_size // 3),       # 33%
        max(20, actual_size // 2),       # 50%
        max(25, 3 * actual_size // 5),   # 60%
        max(30, 2 * actual_size // 3),   # 67%
    )
    return sorted({min(size, actual_size) for size in candidates})


def _ablation_base_params(alg_name, decomp_size, overlap):
    """Build the algorithm parameter dict for one ablation configuration.

    Earlier sweeps used different settings, kept here for reference:
    UCB: learning_rate=0.5, nb_rounds=5, n_weight_vectors=20;
    Thompson: initial_temperature=2.0, temp_decay=0.995, hybrid_ratio=0.7,
    max_iterations=100, nb_rounds=30, use_diminishing_overlap=True.
    """
    params = {
        'learning_rate': 0.1,
        'initial_temperature': 1.0,
        'temp_decay': 0.98,
        'hybrid_ratio': 0.5,
        'adaptive_hybrid': True,
        'decomposition_size': decomp_size,
        'overlap': overlap,
        'max_iterations': 80,
        'nb_rounds': 15,
        'patience': 30,
        'use_lagrangian': True,
        'use_ftrl': True,
        'dual_step_size': 1.0,
        'use_accelerated_dual': True,
        'use_diminishing_overlap': False,
        'n_weight_vectors': 30,
    }

    # Algorithm-specific overrides, selected by substring of the display name.
    if 'UCB' in alg_name:
        params['ucb_coefficient'] = 3.0
    elif 'Thompson' in alg_name:
        params['hybrid_ratio'] = 0.3
        params['temp_decay'] = 0.995

    return params


def ablation_overlap_vs_decomposition(
    algorithm_classes: Dict[str, Any],
    problem_type: str = 'BiTSP',
    problem_size: str = 'medium',
    decomposition_sizes: Optional[List[int]] = None,
    overlap_ratios: Optional[List[float]] = None,
    num_runs: int = 5,
    output_dir: str = 'ablation_results',
    plot_metrics: Optional[List[str]] = None,
    baseline_solvers: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Ablation: test overlap % (x-axis) for different decomposition sizes (colored lines).

    For every algorithm, decomposition size and overlap ratio the algorithm is
    evaluated ``num_runs`` times. Results are written to a timestamped YAML
    file and to CSV (via ``save_results_to_csv``), and up to three of the
    selected metrics are plotted per algorithm.

    Parameters:
    -----------
    algorithm_classes : Dict[str, Any]
        Algorithms to test, keyed by display name. Names containing 'UCB' or
        'Thompson' receive algorithm-specific parameter overrides.
    problem_type : str
        'BiTSP' or 'BiKP' (a key of PROBLEM_REGISTRY)
    problem_size : str
        'small', 'medium', or 'large'
    decomposition_sizes : List[int]
        Different decomposition sizes (each becomes a colored line).
        Default: derived from the problem size and clamped to it.
    overlap_ratios : List[float]
        Overlap fractions of the decomposition size to test (x-axis points).
        Default: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    plot_metrics : List[str]
        Metrics to plot. Options:
        ['hypervolume', 'runtime', 'solutions', 'tour_length']
        Default: ['hypervolume', 'runtime', 'tour_length']
    baseline_solvers : Dict[str, Dict[str, float]]
        Optional baseline solver results to show as reference lines.
        Example: {'WS': {'hypervolume': 0.5, 'runtime': 10.0}}

    Returns:
    --------
    The nested mapping results[alg_name][decomp_size][overlap_ratio] -> list
    of per-run dicts ('hypervolume', 'runtime', 'num_solutions',
    'tour_length', 'objectives').

    Raises:
    -------
    ValueError
        If ``problem_type`` is not in PROBLEM_REGISTRY.
    """
    import traceback

    if plot_metrics is None:
        plot_metrics = ['hypervolume', 'runtime', 'tour_length']

    print("\n" + "="*80)
    print("ABLATION: OVERLAP % vs PERFORMANCE (Multiple Decomposition Sizes)")
    print("="*80)

    if problem_type not in PROBLEM_REGISTRY:
        raise ValueError(f"Unknown problem type: {problem_type}")

    problem_config = PROBLEM_REGISTRY[problem_type]
    problem_class = problem_config['class']
    problem_params = problem_config['sizes'][problem_size]

    # Actual problem size: number of cities (TSP) or items (knapsack).
    actual_size = problem_params.get('n_cities', problem_params.get('n_items', 50))

    if decomposition_sizes is None:
        decomposition_sizes = _default_decomposition_sizes(actual_size)

    if overlap_ratios is None:
        overlap_ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

    print(f"\nProblem: {problem_type} ({problem_size}), size={actual_size}")
    print(f"Algorithms: {list(algorithm_classes.keys())}")
    print(f"Decomposition sizes: {decomposition_sizes}")
    print(f"Overlap ratios: {overlap_ratios}")
    print(f"Runs per config: {num_runs}")
    print(f"Plotting metrics: {plot_metrics}")

    ref_point = problem_config['default_ref'](actual_size)
    print(f"Reference point: {ref_point}")

    # The "tour_length" metric is the mean sum of the two objectives; it is
    # negated for the knapsack problem, as in the original reporting.
    quality_sign = {'BiTSP': 1.0, 'BiKP': -1.0}.get(problem_type)

    # results[alg_name][decomp_size][overlap_ratio] = list of per-run dicts
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    for alg_name, alg_class in algorithm_classes.items():
        print(f"\n{'-'*60}")
        print(f"Testing {alg_name}")
        print(f"{'-'*60}")

        for decomp_size in decomposition_sizes:
            print(f"\n  Decomposition size: {decomp_size}")

            for overlap_ratio in overlap_ratios:
                overlap = max(0, int(decomp_size * overlap_ratio))

                print(f"    Overlap: {overlap_ratio*100:.0f}% ({overlap} abs)")

                base_params = _ablation_base_params(alg_name, decomp_size, overlap)

                for run in range(num_runs):
                    # Best effort: a failing run is reported and skipped so the
                    # rest of the sweep still completes.
                    try:
                        evaluator = MOCOEvaluator(reference_point=ref_point)
                        result = evaluator.evaluate_algorithm(
                            algorithm_class=alg_class,
                            problem_class=problem_class,
                            algorithm_name=f"{alg_name}",
                            # Fresh copy per run so the evaluator cannot leak
                            # mutations into later runs.
                            parameters=dict(base_params),
                            problem_params=problem_params,
                            num_runs=1
                        )

                        objectives = result.objectives
                        # len() instead of truthiness: works for lists and arrays.
                        has_objectives = objectives is not None and len(objectives) > 0

                        tour_length = None
                        if quality_sign is not None and has_objectives:
                            tour_length = quality_sign * float(
                                np.mean([obj[0] + obj[1] for obj in objectives])
                            )

                        results[alg_name][decomp_size][overlap_ratio].append({
                            'hypervolume': result.hypervolume,
                            'runtime': result.runtime,
                            'num_solutions': result.num_nondominated,
                            'tour_length': tour_length,
                            'objectives': objectives
                        })

                        tour_str = f"{tour_length:.2f}" if tour_length is not None else "N/A"
                        print(f"      Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, "
                            f"TourLen: {tour_str}")

                    except Exception as e:
                        print(f"      Run {run+1} failed: {e}")
                        traceback.print_exc()

    # Save results
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    results_file = os.path.join(output_dir, f'ablation_overlap_decomp_{problem_type}_{timestamp}.yaml')

    # Convert NumPy scalars/arrays and tuples to builtins first; otherwise
    # yaml.dump emits !!python/... tags that load_and_replot's
    # yaml.safe_load refuses to read.
    payload = _ablation_to_builtin({
        'problem_type': problem_type,
        'problem_size': problem_size,
        'actual_size': actual_size,
        'decomposition_sizes': list(decomposition_sizes),
        'overlap_ratios': list(overlap_ratios),
        'num_runs': num_runs,
        'algorithms': list(algorithm_classes.keys()),
        'plot_metrics': list(plot_metrics),
        'results': {
            alg: {
                str(d): {str(o): runs for o, runs in by_overlap.items()}
                for d, by_overlap in by_decomp.items()
            }
            for alg, by_decomp in results.items()
        },
    })
    with open(results_file, 'w') as f:
        yaml.dump(payload, f)

    print(f"\nResults saved to: {results_file}")

    # Save to CSV
    save_results_to_csv(
        results, decomposition_sizes, overlap_ratios,
        problem_type, problem_size, output_dir, timestamp
    )

    # Generate plots
    plot_overlap_decomposition(
        results, decomposition_sizes, overlap_ratios,
        problem_type, problem_size, output_dir, timestamp, actual_size,
        plot_metrics, baseline_solvers
    )

    return results


def _aggregate_overlap_metric(frame, column):
    """Return per-'Overlap %' mean/std of *column* for plotting.

    Values are coerced to numeric (``None`` becomes NaN). Groups with no valid
    value are dropped, while groups with a single run are kept with a std of
    0 so they still appear on the plot.
    """
    values = frame[['Overlap %']].copy()
    values[column] = pd.to_numeric(frame[column], errors='coerce')

    stats = values.groupby('Overlap %')[column].agg(['mean', 'std']).reset_index()
    # The sample std of a single run is NaN; treat it as "no spread".
    stats['std'] = stats['std'].fillna(0.0)
    return stats.dropna(subset=['mean'])


def plot_overlap_decomposition(
    results, decomposition_sizes, overlap_ratios,
    problem_type, problem_size, output_dir, timestamp, actual_size,
    plot_metrics, baseline_solvers=None
):
    """
    Create one figure per algorithm with one subplot per selected metric.

    Each subplot shows the metric against overlap %, with one line (mean) and
    one shaded band (+/- one std) per decomposition size. Figures are saved as
    PNG and PDF in ``output_dir``.

    Parameters:
    -----------
    results : mapping
        results[alg_name][decomp_size][overlap_ratio] -> list of per-run dicts
        with keys 'hypervolume', 'runtime', 'num_solutions' and optionally
        'tour_length'.
    decomposition_sizes : List[int]
        Decomposition sizes to draw (one colored line each).
    overlap_ratios : List[float]
        Overlap ratios to look up in ``results``.
    problem_type, problem_size : str
        Used for axis labels and output file names.
    output_dir : str
        Directory for the output files (created if missing).
    timestamp : str
        Suffix used in the output file names.
    actual_size : int
        Problem size; currently unused by the plot itself.
    plot_metrics : List[str]
        Metrics to plot, from 'hypervolume', 'runtime', 'solutions',
        'tour_length'. Only the first three entries are considered; unknown
        names are skipped with a warning.
    baseline_solvers : Dict[str, Dict[str, float]], optional
        Dictionary mapping solver names to their metric values.
        Example: {
            'WS': {'hypervolume': 0.5, 'runtime': 10.0},
            'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5}
        }
        These will be shown as horizontal dashed reference lines (on the
        hypervolume and runtime panels only).
    """

    print("\nGenerating ICML-style plots...")

    if plot_metrics is None:
        plot_metrics = ['hypervolume', 'runtime', 'tour_length']

    is_tsp = problem_type == 'BiTSP'

    # Per-metric plotting configuration
    metric_configs = {
        'hypervolume': {
            'column': 'Hypervolume',
            'ylabel': 'Hypervolume',
            'marker': 'o',
            'higher_better': True
        },
        'runtime': {
            'column': 'Runtime (s)',
            'ylabel': 'Runtime (seconds)',
            'marker': 's',
            'higher_better': False
        },
        'solutions': {
            'column': 'Solutions',
            'ylabel': 'Number of Solutions',
            'marker': '^',
            'higher_better': True
        },
        'tour_length': {
            'column': 'Tour Length',
            'ylabel': 'Average Tour Length' if is_tsp else 'Average Value',
            'marker': 'd',
            'higher_better': not is_tsp
        },
    }

    # Metrics for which baseline reference lines are supported.
    baseline_metrics = ('hypervolume', 'runtime')
    baseline_colors = ['red', 'orange', 'purple', 'brown']

    # Resolve the panels up front so the figure has exactly one axis per
    # valid metric (no empty panels for short lists or unknown names).
    selected_metrics = []
    for metric_name in plot_metrics[:3]:
        if metric_name in metric_configs:
            selected_metrics.append(metric_name)
        else:
            print(f"Warning: Unknown metric '{metric_name}', skipping...")

    if not selected_metrics:
        print("No valid metrics to plot, skipping...")
        return

    n_panels = len(selected_metrics)

    for alg_name, alg_results in results.items():
        # Flatten the nested results into one row per run. .get() avoids
        # inserting keys into the caller's defaultdicts and tolerates
        # plain dicts with missing entries.
        rows = []
        for decomp_size in decomposition_sizes:
            by_overlap = alg_results.get(decomp_size) or {}
            for overlap_ratio in overlap_ratios:
                for run_result in by_overlap.get(overlap_ratio) or []:
                    rows.append({
                        'Decomposition Size': decomp_size,
                        'Overlap %': overlap_ratio * 100,
                        'Hypervolume': run_result['hypervolume'],
                        'Runtime (s)': run_result['runtime'],
                        'Solutions': run_result['num_solutions'],
                        'Tour Length': run_result.get('tour_length', np.nan)
                    })

        df = pd.DataFrame(rows)

        if df.empty:
            print(f"No data for {alg_name}, skipping...")
            continue

        # constrained_layout is disabled because spacing is adjusted manually
        # below; squeeze=False keeps `axes` 2-D even for a single panel.
        fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4.2),
                                 constrained_layout=False, squeeze=False)

        try:
            # One color per decomposition size
            n_colors = len(decomposition_sizes)
            if n_colors <= 10:
                colors = plt.cm.tab10(np.linspace(0, 1, n_colors))
            elif n_colors <= 20:
                colors = plt.cm.tab20(np.linspace(0, 1, n_colors))
            else:
                colors = plt.cm.viridis(np.linspace(0, 1, n_colors))

            for ax, metric_name in zip(axes[0], selected_metrics):
                config = metric_configs[metric_name]

                # One line per decomposition size
                for i, decomp_size in enumerate(decomposition_sizes):
                    df_decomp = df[df['Decomposition Size'] == decomp_size]
                    if df_decomp.empty:
                        continue

                    agg = _aggregate_overlap_metric(df_decomp, config['column'])
                    if agg.empty:
                        continue

                    ax.plot(agg['Overlap %'], agg['mean'],
                            marker=config['marker'], linewidth=2, markersize=6,
                            color=colors[i], label=f'D={decomp_size}', alpha=0.85)

                    # Error band: mean +/- one standard deviation
                    ax.fill_between(agg['Overlap %'],
                                    agg['mean'] - agg['std'],
                                    agg['mean'] + agg['std'],
                                    alpha=0.12, color=colors[i])

                # Baseline solver reference lines
                if baseline_solvers and metric_name in baseline_metrics:
                    baseline_idx = 0
                    for solver_name, solver_metrics in baseline_solvers.items():
                        if metric_name not in solver_metrics:
                            continue
                        color = baseline_colors[baseline_idx % len(baseline_colors)]
                        ax.axhline(y=solver_metrics[metric_name], color=color,
                                   linestyle='--', linewidth=2, alpha=0.7,
                                   label=f'{solver_name}')
                        baseline_idx += 1

                ax.set_xlabel('Overlap (% of Decomposition Size)', fontsize=10)
                ax.set_ylabel(config['ylabel'], fontsize=10)

                # Title with direction indicator
                direction = "↑ better" if config['higher_better'] else "↓ better"
                ax.set_title(f'{config["ylabel"]} ({direction})',
                             fontsize=11, fontweight='bold', pad=10)

                # Single legend covering data lines and baselines; skipped
                # when nothing was drawn on this axis.
                handles, labels = ax.get_legend_handles_labels()
                if handles:
                    ax.legend(handles, labels, frameon=True, framealpha=0.9,
                              edgecolor='0.8', loc='best', fontsize=8,
                              ncol=2 if len(labels) > 6 else 1)

                ax.grid(True, alpha=0.3, linestyle='--')

            # Manual spacing adjustment
            fig.subplots_adjust(top=0.90, bottom=0.12, left=0.06, right=0.98, wspace=0.2)

            os.makedirs(output_dir, exist_ok=True)
            file_stem = os.path.join(
                output_dir,
                f'ablation_overlap_{alg_name}_{problem_type}_{problem_size}_{timestamp}'
            )
            plot_file = file_stem + '.png'
            fig.savefig(plot_file, dpi=300, bbox_inches='tight', pad_inches=0.1)
            fig.savefig(file_stem + '.pdf', bbox_inches='tight', pad_inches=0.1)
            print(f"Saved: {plot_file}")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    print("\nICML-style plots generated!")


def load_and_replot(
    yaml_file: str,
    plot_metrics: List[str] = None,
    baseline_solvers: Dict[str, Dict[str, float]] = None,
    output_dir: str = None,
    output_suffix: str = 'replot'
):
    """
    Rebuild the ablation plots from a previously saved YAML results file.

    Useful for adding or changing baseline reference lines, or choosing a
    different set of metrics, without re-running the experiments.

    Parameters:
    -----------
    yaml_file : str
        Path to the saved YAML results file
    plot_metrics : List[str], optional
        Metrics to plot. Options: ['hypervolume', 'runtime', 'solutions', 'tour_length']
        If None, the metrics stored in the file are used (falling back to
        ['hypervolume', 'runtime', 'tour_length'] when the file has none).
    baseline_solvers : Dict[str, Dict[str, float]], optional
        Baseline solver results to show as reference lines.
        Example: {'WS': {'hypervolume': 0.5, 'runtime': 10.0}}
    output_dir : str, optional
        Directory for the new plots. If None, the directory containing the
        YAML file is used ('.' when the path has no directory part).
    output_suffix : str
        Suffix appended to the new timestamp in plot filenames (default: 'replot')

    Returns:
    --------
    dict : The loaded results, keyed alg_name -> decomp_size (int) ->
        overlap_ratio (float) -> list of runs

    Example:
    --------
    >>> baseline_solvers = {
    ...     'WS': {'hypervolume': 0.5, 'runtime': 10.0},
    ...     'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5}
    ... }
    >>> load_and_replot(
    ...     'ablation_results/ablation_overlap_decomp_BiTSP_20250124-143025.yaml',
    ...     plot_metrics=['hypervolume', 'runtime'],
    ...     baseline_solvers=baseline_solvers
    ... )
    """
    rule = "=" * 80

    print("\n" + rule)
    print("LOADING SAVED RESULTS AND REGENERATING PLOTS")
    print(rule)

    print(f"\nLoading: {yaml_file}")
    with open(yaml_file, 'r') as handle:
        saved_data = yaml.safe_load(handle)

    # Metadata stored alongside the results
    (problem_type, problem_size, actual_size,
     decomposition_sizes, overlap_ratios, algorithms) = (
        saved_data[field]
        for field in (
            'problem_type', 'problem_size', 'actual_size',
            'decomposition_sizes', 'overlap_ratios', 'algorithms',
        )
    )

    # Fall back to the metrics recorded by the original run
    if plot_metrics is None:
        plot_metrics = saved_data.get('plot_metrics', ['hypervolume', 'runtime', 'tour_length'])

    print(f"\nProblem: {problem_type} ({problem_size}), size={actual_size}")
    print(f"Algorithms: {algorithms}")
    print(f"Decomposition sizes: {decomposition_sizes}")
    print(f"Overlap ratios: {overlap_ratios}")
    print(f"Plotting metrics: {plot_metrics}")

    if baseline_solvers:
        print(f"Baseline solvers: {list(baseline_solvers.keys())}")

    # The file stores decomposition sizes and overlap ratios as string keys;
    # restore them to int / float so they match the metadata lists.
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for alg_name, per_decomp in saved_data['results'].items():
        for decomp_key, per_overlap in per_decomp.items():
            decomp_size = int(decomp_key)
            for overlap_key, runs in per_overlap.items():
                results[alg_name][decomp_size][float(overlap_key)] = runs

    # Default to the YAML file's own directory
    if output_dir is None:
        output_dir = os.path.dirname(yaml_file) or '.'

    os.makedirs(output_dir, exist_ok=True)

    # New timestamp (plus suffix) so earlier plots are not overwritten
    timestamp = f"{time.strftime('%Y%m%d-%H%M%S')}_{output_suffix}"

    print("\nRegenerating plots...")
    plot_overlap_decomposition(
        results, decomposition_sizes, overlap_ratios,
        problem_type, problem_size, output_dir, timestamp, actual_size,
        plot_metrics, baseline_solvers
    )

    print("\n" + rule)
    print("REPLOTTING COMPLETE!")
    print(f"Check {output_dir}/ for new plots")
    print(rule)

    return dict(results)


# Only this function includes the convergence study.
def decomposition_diagnostic_with_convergence(
    algorithm_class,
    problem_type: str = 'BiObjectiveTSP',
    problem_size: str = 'medium',
    decomp_sizes: List[int] = None,
    num_runs: int = 3,
    output_dir: str = 'diagnostic_results',
    plot_options: Dict[str, bool] = None
):
    """
    Comprehensive diagnostic analyzing decomposition size effect with convergence metrics.
    """
    
    if plot_options is None:
        plot_options = {
            'n_subproblems': True,
            'main_iterations': True,
            'convergence_iterations': True,
            'total_work': True,
            'wall_time': True,
            'hypervolume': True,
            'tour_length': True,
        }
    
    # Get problem configuration
    problem_config = PROBLEM_REGISTRY[problem_type]
    problem_class = problem_config['class']
    problem_params = problem_config['sizes'][problem_size]
    
    if 'n_cities' in problem_params:
        actual_size = problem_params['n_cities']
    elif 'n_items' in problem_params:
        actual_size = problem_params['n_items']
    else:
        actual_size = 50
    
    # Auto-generate decomposition sizes
    if decomp_sizes is None:
        decomp_sizes = [5, 10, 15, 20, 30, 40]
        decomp_sizes = [d for d in decomp_sizes if d < actual_size]
    
    # Get reference point
    ref_point = problem_config['default_ref'](actual_size)
    print(f"\n{'='*80}")
    print(f"DECOMPOSITION DIAGNOSTIC WITH CONVERGENCE TRACKING")
    print(f"{'='*80}")
    print(f"Problem: {problem_type} (N={actual_size})")
    print(f"Reference point: {ref_point}")
    print(f"Decomposition sizes: {decomp_sizes}")
    print(f"Runs per size: {num_runs}")
    print(f"Looking for U-curves in: convergence iterations & total work")
    
    results = []
    
    for D in decomp_sizes:
        print(f"\n{'='*60}")
        print(f"Testing D={D}")
        print(f"{'='*60}")
        
        run_data = []
        
        for run in range(num_runs):
            # Algorithm parameters
            algorithm_params = {
                'learning_rate': 0.5,
                'ucb_coefficient': 3.0,
                'initial_temperature': 1.0,
                'temp_decay': 0.98,
                'hybrid_ratio': 0.7,
                'adaptive_hybrid': True,
                'decomposition_size': D,
                'overlap': max(2, D // 3),
                'max_iterations': 80,
                'nb_rounds': 10,
                'patience': 20,
                'use_lagrangian': True,
                'use_ftrl': True,
                'dual_step_size': 1.0,
                'use_accelerated_dual': True,
                'use_diminishing_overlap': False,
                'n_weight_vectors': 15,
            }
            
            # Create evaluator
            evaluator = MOCOEvaluator(reference_point=ref_point)
            
            start = time.time()
            
            # Run algorithm
            result = evaluator.evaluate_algorithm(
                algorithm_class=algorithm_class,
                problem_class=problem_class,
                algorithm_name=f"Decomp_D{D}",
                parameters=algorithm_params,
                problem_params=problem_params,
                num_runs=1
            )
            
            wall_time = time.time() - start
            
            # Calculate tour length/quality
            if problem_type == 'BiObjectiveTSP' and result.objectives:
                tour_length = np.mean([obj[0] + obj[1] for obj in result.objectives])
            elif problem_type == 'MultiObjectiveKnapsack' and result.objectives:
                tour_length = -np.mean([obj[0] + obj[1] for obj in result.objectives])
            elif problem_type == 'TriObjectiveTSP' and result.objectives:
                tour_length = np.mean([sum(obj) for obj in result.objectives])
            else:
                tour_length = 0.0
            
            # Try to get convergence stats
            convergence_stats = None
            if hasattr(result, '_wrapper_instance'):
                wrapper = result._wrapper_instance
                if hasattr(wrapper, 'get_aggregated_stats'):
                    convergence_stats = wrapper.get_aggregated_stats()
            
            # Store metrics
            run_result = {
                'wall_time': wall_time,
                'hypervolume': result.hypervolume,
                'num_solutions': result.num_nondominated,
                'tour_length': tour_length,
                'runtime': result.runtime
            }
            
            # Add convergence stats if available
            if convergence_stats:
                run_result.update({
                    'main_loop_iters': convergence_stats['main_loop_iters_mean'],
                    'convergence_iter': convergence_stats['convergence_iters_mean'],
                    'total_subprob_work': convergence_stats['total_subprob_evals_mean'],
                    'n_subprobs': convergence_stats['n_subproblems']
                })
            else:
                # Estimate if stats not available
                run_result.update({
                    'main_loop_iters': 80,
                    'convergence_iter': 80,
                    'total_subprob_work': 80 * max(1, actual_size // D),
                    'n_subprobs': max(1, actual_size // D)
                })
            
            run_data.append(run_result)
            
            print(f"  Run {run+1}: {wall_time:.2f}s, HV={result.hypervolume:.4f}, "
                  f"Iters={run_result['main_loop_iters']:.0f}, "
                  f"Converged@{run_result['convergence_iter']:.0f}, "
                  f"Work={run_result['total_subprob_work']:.0f}")
        
        # Aggregate
        results.append({
            'D': D,
            'n_subprobs': np.mean([r['n_subprobs'] for r in run_data]),
            'wall_time_mean': np.mean([r['wall_time'] for r in run_data]),
            'wall_time_std': np.std([r['wall_time'] for r in run_data]),
            'hv_mean': np.mean([r['hypervolume'] for r in run_data]),
            'hv_std': np.std([r['hypervolume'] for r in run_data]),
            'solutions_mean': np.mean([r['num_solutions'] for r in run_data]),
            'solutions_std': np.std([r['num_solutions'] for r in run_data]),
            'tour_len_mean': np.mean([r['tour_length'] for r in run_data]),
            'tour_len_std': np.std([r['tour_length'] for r in run_data]),
            'main_iters_mean': np.mean([r['main_loop_iters'] for r in run_data]),
            'main_iters_std': np.std([r['main_loop_iters'] for r in run_data]),
            'conv_iter_mean': np.mean([r['convergence_iter'] for r in run_data]),
            'conv_iter_std': np.std([r['convergence_iter'] for r in run_data]),
            'total_work_mean': np.mean([r['total_subprob_work'] for r in run_data]),
            'total_work_std': np.std([r['total_subprob_work'] for r in run_data]),
        })
        
        print(f"  → Avg: {results[-1]['wall_time_mean']:.2f}s, "
              f"HV={results[-1]['hv_mean']:.4f}, "
              f"Converged@{results[-1]['conv_iter_mean']:.1f}, "
              f"Work={results[-1]['total_work_mean']:.0f}")
    
    # Print comprehensive results
    print("\n" + "="*100)
    print("COMPREHENSIVE RESULTS")
    print("="*100)
    print(f"{'D':>4} | {'Subprobs':>8} | {'Wall Time':>12} | {'Main Iters':>12} | "
          f"{'Conv@Iter':>12} | {'Total Work':>12} | {'HV':>10} | {'Tour Len':>10}")
    print("-"*100)
    
    for r in results:
        print(f"{r['D']:>4} | {r['n_subprobs']:>8.1f} | "
              f"{r['wall_time_mean']:>7.2f}±{r['wall_time_std']:>3.2f} | "
              f"{r['main_iters_mean']:>7.1f}±{r['main_iters_std']:>3.1f} | "
              f"{r['conv_iter_mean']:>7.1f}±{r['conv_iter_std']:>3.1f} | "
              f"{r['total_work_mean']:>7.0f}±{r['total_work_std']:>3.0f} | "
              f"{r['hv_mean']:>6.4f}±{r['hv_std']:>.3f} | "
              f"{r['tour_len_mean']:>6.2f}±{r['tour_len_std']:>3.2f}")
    print("="*100)
    
    # ✅ CREATE timestamp ONCE
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    # Save to CSV
    csv_file = save_decomposition_results_to_csv(
        results, problem_type, problem_size, output_dir, timestamp
    )

    # Analysis
    conv_iters = [r['conv_iter_mean'] for r in results]
    total_works = [r['total_work_mean'] for r in results]
    wall_times = [r['wall_time_mean'] for r in results]
    hvs = [r['hv_mean'] for r in results]
    tour_lens = [r['tour_len_mean'] for r in results]
    
    min_conv_idx = conv_iters.index(min(conv_iters))
    min_work_idx = total_works.index(min(total_works))
    min_time_idx = wall_times.index(min(wall_times))
    max_hv_idx = hvs.index(max(hvs))
    min_tour_idx = tour_lens.index(min(tour_lens))
    
    print(f"\n🎯 U-CURVE ANALYSIS:")
    print(f"  Fastest convergence: D={results[min_conv_idx]['D']} ({conv_iters[min_conv_idx]:.1f} iters) ⭐")
    print(f"  Least total work: D={results[min_work_idx]['D']} ({total_works[min_work_idx]:.0f} evals) ⭐")
    print(f"  Fastest runtime: D={results[min_time_idx]['D']} ({wall_times[min_time_idx]:.2f}s)")
    print(f"  Best hypervolume: D={results[max_hv_idx]['D']} ({hvs[max_hv_idx]:.4f})")
    print(f"  Best tour length: D={results[min_tour_idx]['D']} ({tour_lens[min_tour_idx]:.2f})")
    print(f"  Theoretical optimal: D≈√{actual_size}≈{int(np.sqrt(actual_size))}")
    
    # Check for U-curves
    print(f"\n📊 CONVERGENCE ITERATIONS TREND (U-curve expected):")
    for i, r in enumerate(results):
        if i == min_conv_idx:
            marker = "⭐ MINIMUM (U-curve)"
        elif i == 0:
            marker = "↓" if conv_iters[i] > conv_iters[i+1] else "↑"
        elif i == len(results) - 1:
            marker = "↑" if conv_iters[i] > conv_iters[i-1] else "↓"
        else:
            if conv_iters[i] < conv_iters[i-1] and conv_iters[i] < conv_iters[i+1]:
                marker = "◆ LOCAL MIN"
            elif conv_iters[i] > conv_iters[i-1] and conv_iters[i] > conv_iters[i+1]:
                marker = "▲ LOCAL MAX"
            else:
                marker = "↑" if conv_iters[i] > conv_iters[i-1] else "↓"
        print(f"  D={r['D']:>2}: {conv_iters[i]:>5.1f} iters {marker}")
    
    print(f"\n📊 TOTAL WORK TREND (U-curve expected):")
    for i, r in enumerate(results):
        if i == min_work_idx:
            marker = "⭐ MINIMUM (U-curve)"
        else:
            marker = "↑" if i > 0 and total_works[i] > total_works[i-1] else "↓"
        print(f"  D={r['D']:>2}: {total_works[i]:>6.0f} evals {marker}")
    
    # ✅ CREATE PLOTS
    n_plots = sum(plot_options.values())
    if n_plots == 0:
        print("\nNo plots enabled!")
        return results
    
    # Create figure
    if n_plots <= 4:
        fig, axes = plt.subplots(2, 2, figsize=(10, 7), constrained_layout=False)
    elif n_plots <= 6:
        fig, axes = plt.subplots(2, 3, figsize=(15, 7), constrained_layout=False)
    else:
        fig, axes = plt.subplots(3, 3, figsize=(15, 10), constrained_layout=False)
    
    axes = axes.flatten()
    D_vals = [r['D'] for r in results]
    plot_idx = 0
    
    # Color palette
    color_map = {
        'n_subproblems': '#0173B2',
        'main_iterations': '#029E73',
        'convergence_iterations': '#CC3311',
        'total_work': '#9933CC',
        'wall_time': '#0173B2',
        'hypervolume': '#029E73',
        'tour_length': '#333333',
    }
    
    # 1. Number of subproblems
    if plot_options.get('n_subproblems', False):
        n_subprobs = [r['n_subprobs'] for r in results]
        axes[plot_idx].plot(D_vals, n_subprobs, 'o-', 
                           color=color_map['n_subproblems'],
                           linewidth=2, markersize=6)
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Number of Subproblems (N/D)', fontsize=10)
        axes[plot_idx].set_title('Subproblem Count (↓ with D)', 
                                fontsize=11, fontweight='bold', pad=8)
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 2. Main loop iterations
    if plot_options.get('main_iterations', False):
        main_iters = [r['main_iters_mean'] for r in results]
        main_std = [r['main_iters_std'] for r in results]
        axes[plot_idx].errorbar(D_vals, main_iters, yerr=main_std,
                               fmt='o-', color=color_map['main_iterations'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Main Loop Iterations', fontsize=10)
        axes[plot_idx].set_title('Iterations Run', fontsize=11, fontweight='bold', pad=8)
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 3. Convergence iterations (U-CURVE!)
    if plot_options.get('convergence_iterations', False):
        conv_std = [r['conv_iter_std'] for r in results]
        
        axes[plot_idx].errorbar(D_vals, conv_iters, yerr=conv_std,
                               fmt='o-', color=color_map['convergence_iterations'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        
        # Highlight minimum
        axes[plot_idx].scatter(results[min_conv_idx]['D'], conv_iters[min_conv_idx],
                              s=400, c='gold', marker='*', 
                              edgecolors=color_map['convergence_iterations'], 
                              linewidths=2.5, zorder=5)
        
        # Theoretical optimum
        sqrt_n = int(np.sqrt(actual_size))
        axes[plot_idx].axvline(sqrt_n, color='green', linestyle='--',
                              alpha=0.4, linewidth=1.5, label=f'√N={sqrt_n}')
        
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Iterations to Converge', fontsize=10)
        axes[plot_idx].set_title('Convergence Speed (U-curve!) ⭐', 
                                fontsize=11, fontweight='bold', color='#CC3311', pad=8)
        axes[plot_idx].legend(fontsize=8, framealpha=0.9, edgecolor='0.8')
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 4. Total work (U-CURVE!)
    if plot_options.get('total_work', False):
        total_std = [r['total_work_std'] for r in results]
        
        axes[plot_idx].errorbar(D_vals, total_works, yerr=total_std,
                               fmt='o-', color=color_map['total_work'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        
        # Highlight minimum
        axes[plot_idx].scatter(results[min_work_idx]['D'], total_works[min_work_idx],
                              s=400, c='gold', marker='*', 
                              edgecolors=color_map['total_work'], 
                              linewidths=2.5, zorder=5)
        
        # Theoretical optimum
        sqrt_n = int(np.sqrt(actual_size))
        axes[plot_idx].axvline(sqrt_n, color='green', linestyle='--',
                              alpha=0.4, linewidth=1.5, label=f'√N={sqrt_n}')
        
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Total Subproblem Evaluations', fontsize=10)
        axes[plot_idx].set_title('Total Work (U-curve!) ⭐', 
                                fontsize=11, fontweight='bold', color='#9933CC', pad=8)
        axes[plot_idx].legend(fontsize=8, framealpha=0.9, edgecolor='0.8')
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 5. Wall time
    if plot_options.get('wall_time', False):
        wall_std = [r['wall_time_std'] for r in results]
        
        axes[plot_idx].errorbar(D_vals, wall_times, yerr=wall_std,
                               fmt='o-', color=color_map['wall_time'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        
        axes[plot_idx].scatter(results[min_time_idx]['D'], wall_times[min_time_idx],
                              s=400, c='gold', marker='*', 
                              edgecolors=color_map['wall_time'], 
                              linewidths=2.5, zorder=5)
        
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Wall Clock Time (s)', fontsize=10)
        axes[plot_idx].set_title('Total Runtime (Monotonic ↓)', 
                                fontsize=11, fontweight='bold', pad=8)
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 6. Hypervolume
    if plot_options.get('hypervolume', False):
        hv_std = [r['hv_std'] for r in results]
        
        axes[plot_idx].errorbar(D_vals, hvs, yerr=hv_std,
                               fmt='o-', color=color_map['hypervolume'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        
        axes[plot_idx].scatter(results[max_hv_idx]['D'], hvs[max_hv_idx],
                              s=400, c='gold', marker='*', 
                              edgecolors=color_map['hypervolume'], 
                              linewidths=2.5, zorder=5)
        
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        axes[plot_idx].set_ylabel('Hypervolume (↑ better)', fontsize=10)
        axes[plot_idx].set_title('Solution Quality', 
                                fontsize=11, fontweight='bold', pad=8)
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # 7. Tour length
    if plot_options.get('tour_length', False):
        tour_std = [r['tour_len_std'] for r in results]
        
        axes[plot_idx].errorbar(D_vals, tour_lens, yerr=tour_std,
                               fmt='o-', color=color_map['tour_length'],
                               linewidth=2, markersize=6, capsize=4, capthick=1.5)
        
        axes[plot_idx].scatter(results[min_tour_idx]['D'], tour_lens[min_tour_idx],
                              s=400, c='gold', marker='*', 
                              edgecolors=color_map['tour_length'], 
                              linewidths=2.5, zorder=5)
        
        axes[plot_idx].set_xlabel('Decomposition Size (D)', fontsize=10)
        ylabel = 'Tour Length (↓ better)' if problem_type in ['BiObjectiveTSP', 'TriObjectiveTSP'] else 'Value (↑ better)'
        axes[plot_idx].set_ylabel(ylabel, fontsize=10)
        axes[plot_idx].set_title('Tour Quality', 
                                fontsize=11, fontweight='bold', pad=8)
        axes[plot_idx].grid(True, alpha=0.3, linestyle='--')
        plot_idx += 1
    
    # Hide unused subplots
    for i in range(plot_idx, len(axes)):
        axes[i].axis('off')
    
    # Overall title
    fig.suptitle(
        f'Decomposition Size Analysis with Convergence Tracking\n'
        f'{problem_type}, N={actual_size} (Looking for U-curves!)',
        fontsize=13, fontweight='bold'
    )
    
    # Manual spacing
    plt.subplots_adjust(top=0.92, bottom=0.08, left=0.06, right=0.98, 
                       hspace=0.35, wspace=0.25)
    
    # Save
    plot_file = os.path.join(output_dir, 
                            f'decomp_diagnostic_convergence_{problem_type}_{timestamp}.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.savefig(plot_file.replace('.png', '.pdf'), bbox_inches='tight', pad_inches=0.1)
    print(f"\n📊 ICML-style plots saved: {plot_file}")
    plt.close()  # Don't call plt.show() - it blocks in scripts
    
    return results



def _stat_or_nan(reducer, values):
    """Return ``reducer(values)``, or NaN when ``values`` is empty."""
    return reducer(values) if values else np.nan


def save_results_to_csv(results, decomposition_sizes, overlap_ratios, 
                        problem_type, problem_size, output_dir, timestamp):
    """
    Save ablation results to CSV for later analysis.

    Creates two CSV files in ``output_dir``:
    1. Individual runs (one row per run)
    2. Aggregated statistics (mean/std per configuration)

    Args:
        results: Nested mapping
            ``algorithm name -> decomposition size -> overlap ratio -> list
            of run dicts``. Every run dict must contain 'hypervolume',
            'runtime' and 'num_solutions'; 'tour_length' is optional and a
            missing or ``None`` value is excluded from the tour statistics.
        decomposition_sizes: Decomposition sizes to export. Sizes an
            algorithm has no entry for are skipped.
        overlap_ratios: Overlap ratios to export. Ratios absent for a given
            decomposition size are skipped.
        problem_type: Label written to every row and used in the file names.
        problem_size: Label written to every row and used in the file names.
        output_dir: Directory that receives both files; created if missing.
        timestamp: Label written to every row and used in the file names.

    Returns:
        Tuple ``(individual_runs_csv_path, aggregated_csv_path)``.
    """
    # Callers are not required to have created the directory beforehand.
    os.makedirs(output_dir, exist_ok=True)

    individual_data = []
    aggregated_data = []

    # A single traversal feeds both tables so they can never disagree about
    # which configurations were exported.
    for alg_name, alg_results in results.items():
        for decomp_size in decomposition_sizes:
            # .get() avoids a KeyError for sizes this algorithm never ran and
            # avoids inserting a new key when the mapping is a defaultdict.
            by_overlap = alg_results.get(decomp_size, {})
            for overlap_ratio in overlap_ratios:
                if overlap_ratio not in by_overlap:
                    continue

                runs = by_overlap[overlap_ratio]
                overlap_absolute = int(decomp_size * overlap_ratio)

                for run_idx, run_result in enumerate(runs):
                    individual_data.append({
                        'algorithm': alg_name,
                        'problem_type': problem_type,
                        'problem_size': problem_size,
                        'decomposition_size': decomp_size,
                        'overlap_ratio': overlap_ratio,
                        'overlap_absolute': overlap_absolute,
                        'run_number': run_idx + 1,
                        'hypervolume': run_result['hypervolume'],
                        'runtime_seconds': run_result['runtime'],
                        'num_solutions': run_result['num_solutions'],
                        'tour_length': run_result.get('tour_length', np.nan),
                        'timestamp': timestamp
                    })

                hv_values = [r['hypervolume'] for r in runs]
                runtime_values = [r['runtime'] for r in runs]
                solution_values = [r['num_solutions'] for r in runs]
                tour_values = [r['tour_length'] for r in runs
                               if r.get('tour_length') is not None]

                # A configuration whose runs all failed has an empty list.
                # np.min/np.max raise ValueError on empty input, so such a
                # configuration is reported with n_runs == 0 and NaN stats.
                aggregated_data.append({
                    'algorithm': alg_name,
                    'problem_type': problem_type,
                    'problem_size': problem_size,
                    'decomposition_size': decomp_size,
                    'overlap_ratio': overlap_ratio,
                    'overlap_absolute': overlap_absolute,
                    'overlap_percentage': overlap_ratio * 100,
                    'n_runs': len(runs),
                    # Hypervolume stats
                    'hv_mean': _stat_or_nan(np.mean, hv_values),
                    'hv_std': _stat_or_nan(np.std, hv_values),
                    'hv_min': _stat_or_nan(np.min, hv_values),
                    'hv_max': _stat_or_nan(np.max, hv_values),
                    'hv_median': _stat_or_nan(np.median, hv_values),
                    # Runtime stats
                    'runtime_mean': _stat_or_nan(np.mean, runtime_values),
                    'runtime_std': _stat_or_nan(np.std, runtime_values),
                    'runtime_min': _stat_or_nan(np.min, runtime_values),
                    'runtime_max': _stat_or_nan(np.max, runtime_values),
                    # Solution count stats
                    'solutions_mean': _stat_or_nan(np.mean, solution_values),
                    'solutions_std': _stat_or_nan(np.std, solution_values),
                    # Tour length stats
                    'tour_length_mean': _stat_or_nan(np.mean, tour_values),
                    'tour_length_std': _stat_or_nan(np.std, tour_values),
                    'timestamp': timestamp
                })

    # Save individual runs
    df_individual = pd.DataFrame(individual_data)
    csv_file_individual = os.path.join(
        output_dir, 
        f'ablation_individual_runs_{problem_type}_{problem_size}_{timestamp}.csv'
    )
    df_individual.to_csv(csv_file_individual, index=False)
    print(f"Individual runs saved to: {csv_file_individual}")

    # Save aggregated statistics
    df_aggregated = pd.DataFrame(aggregated_data)
    csv_file_aggregated = os.path.join(
        output_dir, 
        f'ablation_aggregated_{problem_type}_{problem_size}_{timestamp}.csv'
    )
    df_aggregated.to_csv(csv_file_aggregated, index=False)
    print(f"Aggregated statistics saved to: {csv_file_aggregated}")

    return csv_file_individual, csv_file_aggregated

def save_decomposition_results_to_csv(results, problem_type, problem_size, 
                                      output_dir, timestamp):
    """Write the per-D decomposition diagnostic summary to a CSV file.

    ``results`` is turned into a DataFrame, three metadata columns
    (problem type, problem size, timestamp) are placed in front, and the
    frame is restricted to a fixed column layout before being written
    without the index. Every column named in that layout must be present
    in ``results``. The path of the written file is printed and returned.
    """
    # Metadata columns, in the order they should lead the table.
    leading = (
        ('problem_type', problem_type),
        ('problem_size', problem_size),
        ('timestamp', timestamp),
    )

    # Metric columns that follow the metadata, grouped as mean/std pairs.
    metrics = [
        'D', 'n_subprobs',
        'wall_time_mean', 'wall_time_std',
        'main_iters_mean', 'main_iters_std',
        'conv_iter_mean', 'conv_iter_std',
        'total_work_mean', 'total_work_std',
        'hv_mean', 'hv_std',
        'solutions_mean', 'solutions_std',
        'tour_len_mean', 'tour_len_std',
    ]

    table = pd.DataFrame(results)
    for position, (label, value) in enumerate(leading):
        table.insert(position, label, value)

    # Selecting by the explicit layout fixes the order and drops extras.
    layout = [label for label, _ in leading] + metrics
    table = table[layout]

    csv_file = os.path.join(
        output_dir, 
        f'decomposition_diagnostic_{problem_type}_{problem_size}_{timestamp}.csv'
    )
    table.to_csv(csv_file, index=False)
    print(f"\n📊 CSV saved to: {csv_file}")

    return csv_file

#---------------

# def experiment_overlap_harm_detection(
#     algorithm_class,
#     problem_type: str = 'BiTSP',
#     problem_size: str = 'medium',
#     decomposition_sizes: List[int] = None,
#     overlap_ratios: List[float] = None,
#     num_runs: int = 5,
#     output_dir: str = 'harm_detection_results',
# ):
#     """
#     Experiment to identify when increasing overlap becomes harmful.
    
#     Measures:
#     1. Computational Redundancy Ratio
#     2. Marginal Benefit (ΔHV / ΔRuntime)
#     3. Convergence Penalty
    
#     Returns harm thresholds for each decomposition size.
#     """
    
#     print("\n" + "="*80)
#     print("EXPERIMENT: OVERLAP HARM DETECTION")
#     print("="*80)
    
#     if problem_type not in PROBLEM_REGISTRY:
#         raise ValueError(f"Unknown problem type: {problem_type}")
    
#     problem_config = PROBLEM_REGISTRY[problem_type]
#     problem_class = problem_config['class']
#     problem_params = problem_config['sizes'][problem_size]
    
#     # Get actual problem size
#     if 'n_cities' in problem_params:
#         actual_size = problem_params['n_cities']
#     elif 'n_items' in problem_params:
#         actual_size = problem_params['n_items']
#     else:
#         actual_size = 50
    
#     # Auto-generate decomposition sizes if not provided
#     if decomposition_sizes is None:
#         decomposition_sizes = [10, 20, 30, 50]
    
#     # Dense overlap ratios for better derivative estimation
#     if overlap_ratios is None:
#         overlap_ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    
#     ref_point = problem_config['default_ref'](actual_size)
    
#     print(f"\nProblem: {problem_type} ({problem_size}), size={actual_size}")
#     print(f"Decomposition sizes: {decomposition_sizes}")
#     print(f"Overlap ratios: {overlap_ratios}")
#     print(f"Runs per config: {num_runs}")
    
#     # FIX: Results storage with proper initialization
#     results = {}
#     for decomp_size in decomposition_sizes:
#         results[decomp_size] = {}
#         for overlap_ratio in overlap_ratios:
#             results[decomp_size][overlap_ratio] = []  # Initialize as list
    
#     # Run experiments with detailed tracking
#     for decomp_size in decomposition_sizes:
#         print(f"\n{'='*60}")
#         print(f"Testing D={decomp_size}")
#         print(f"{'='*60}")
        
#         for overlap_ratio in overlap_ratios:
#             overlap = max(0, int(decomp_size * overlap_ratio))
            
#             print(f"\n  Overlap: {overlap_ratio*100:.0f}% ({overlap} abs)")
            
#             # Calculate theoretical redundancy
#             stride = max(1, decomp_size - overlap)
#             n_subproblems = max(1, (actual_size - decomp_size) // stride + 1)
            
#             # Theoretical metrics
#             total_indices_processed = n_subproblems * decomp_size
#             unique_indices = actual_size  # All indices covered at least once
#             theoretical_redundancy = total_indices_processed / unique_indices
            
#             print(f"    Subproblems: {n_subproblems}, Stride: {stride}")
#             print(f"    Theoretical redundancy: {theoretical_redundancy:.2f}x")
            
#             # Base parameters
#             base_params = {
#                 'learning_rate': 0.5,
#                 'ucb_coefficient': 3.0,
#                 'initial_temperature': 1.0,
#                 'temp_decay': 0.98,
#                 'hybrid_ratio': 0.7,
#                 'adaptive_hybrid': True,
#                 'decomposition_size': decomp_size,
#                 'overlap': overlap,
#                 'max_iterations': 100,
#                 'nb_rounds': 5,
#                 'patience': 20,
#                 'use_lagrangian': True,
#                 'use_ftrl': True,
#                 'dual_step_size': 1.0,
#                 'use_accelerated_dual': True,
#                 'use_diminishing_overlap': False,
#                 'n_weight_vectors': 15,
#             }
            
#             # Run multiple times
#             for run in range(num_runs):
#                 try:
#                     evaluator = MOCOEvaluator(reference_point=ref_point)
                    
#                     start_time = time.time()
#                     result = evaluator.evaluate_algorithm(
#                         algorithm_class=algorithm_class,
#                         problem_class=problem_class,
#                         algorithm_name=f"D{decomp_size}_O{int(overlap_ratio*100)}",
#                         parameters=base_params,
#                         problem_params=problem_params,
#                         num_runs=1
#                     )
#                     wall_time = time.time() - start_time
                    

#                     # Calculate tour length
#                     tour_length = None
#                     if problem_type == 'BiTSP' and result.objectives:
#                         tour_length = np.mean([obj[0] + obj[1] for obj in result.objectives])
#                     elif problem_type == 'BiKP' and result.objectives:
#                         tour_length = -np.mean([obj[0] + obj[1] for obj in result.objectives])
                    
#                     # Try to get convergence stats
#                     convergence_stats = None
#                     convergence_iter = 80  # default
#                     total_work = 80 * n_subproblems  # default estimate
#                     stagnation_count = 0
                    
#                     if hasattr(result, '_wrapper_instance'):
#                         wrapper = result._wrapper_instance
#                         if hasattr(wrapper, 'get_aggregated_stats'):
#                             convergence_stats = wrapper.get_aggregated_stats()
#                             convergence_iter = convergence_stats.get('convergence_iters_mean', 80)
#                             total_work = convergence_stats.get('total_subprob_evals_mean', total_work)
                        
#                         # Try to get per-iteration improvement tracking
#                         if hasattr(wrapper, 'iteration_history'):
#                             history = wrapper.iteration_history
#                             if len(history) > 1:
#                                 improvements = [
#                                     abs(history[i] - history[i-1]) / max(abs(history[i-1]), 1e-10)
#                                     for i in range(1, len(history))
#                                 ]
#                                 stagnation_count = sum(1 for imp in improvements if imp < 1e-4)
                    
#                     # Calculate empirical redundancy
#                     empirical_redundancy = total_work / actual_size
                    
#                     # Store comprehensive metrics
#                     results[decomp_size][overlap_ratio].append({
#                         'hypervolume': result.hypervolume,
#                         'runtime': result.runtime,
#                         'wall_time': wall_time,
#                         'num_solutions': result.num_nondominated,
#                         'tour_length': tour_length,
#                         'n_subproblems': n_subproblems,
#                         'convergence_iter': convergence_iter,
#                         'total_work': total_work,
#                         'theoretical_redundancy': theoretical_redundancy,
#                         'empirical_redundancy': empirical_redundancy,
#                         'stagnation_count': stagnation_count,
#                         'stride': stride,
#                     })
                    
#                     print(f"      Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, "
#                           f"Time: {wall_time:.2f}s, Conv@: {convergence_iter:.0f}, "
#                           f"RedundEmp: {empirical_redundancy:.2f}x")
                    
#                 except Exception as e:
#                     print(f"      Run {run+1} failed: {e}")
#                     import traceback
#                     traceback.print_exc()
    
#     # Analyze and compute harm thresholds
#     timestamp = time.strftime("%Y%m%d-%H%M%S")
#     os.makedirs(output_dir, exist_ok=True)
    
#     harm_analysis = analyze_harm_thresholds(
#         results, decomposition_sizes, overlap_ratios, 
#         actual_size, output_dir, timestamp
#     )
    
#     # Generate comprehensive plots
#     plot_harm_detection(
#         results, decomposition_sizes, overlap_ratios,
#         harm_analysis, problem_type, problem_size, 
#         output_dir, timestamp, actual_size
#     )
    
#     # Save detailed results
#     save_harm_detection_to_csv(
#         results, harm_analysis, decomposition_sizes, overlap_ratios,
#         problem_type, problem_size, output_dir, timestamp
#     )
    
#     return results, harm_analysis


def experiment_overlap_harm_detection(
    algorithm_class,
    problem_type: str = 'BiTSP',
    problem_size: str = 'medium',
    decomposition_sizes: List[int] = None,
    overlap_ratios: List[float] = None,
    num_runs: int = 5,
    output_dir: str = 'harm_detection_results',
):
    """
    Experiment to identify when increasing overlap becomes harmful.
    
    Measures:
    1. Computational Redundancy Ratio
    2. Marginal Benefit (ΔHV / ΔRuntime)
    3. Convergence Penalty
    
    Returns harm thresholds for each decomposition size.
    """
    
    print("\n" + "="*80)
    print("EXPERIMENT: OVERLAP HARM DETECTION")
    print("="*80)
    
    if problem_type not in PROBLEM_REGISTRY:
        raise ValueError(f"Unknown problem type: {problem_type}")
    
    problem_config = PROBLEM_REGISTRY[problem_type]
    problem_class = problem_config['class']
    problem_params = problem_config['sizes'][problem_size]
    
    # Get actual problem size
    if 'n_cities' in problem_params:
        actual_size = problem_params['n_cities']
    elif 'n_items' in problem_params:
        actual_size = problem_params['n_items']
    else:
        actual_size = 50
    
    # Auto-generate decomposition sizes if not provided
    if decomposition_sizes is None:
        decomposition_sizes = [10, 20, 30, 50]
    
    # Dense overlap ratios for better derivative estimation
    if overlap_ratios is None:
        overlap_ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    
    ref_point = problem_config['default_ref'](actual_size)
    
    print(f"\nProblem: {problem_type} ({problem_size}), size={actual_size}")
    print(f"Decomposition sizes: {decomposition_sizes}")
    print(f"Overlap ratios: {overlap_ratios}")
    print(f"Runs per config: {num_runs}")
    
    # Results storage with proper initialization
    results = {}
    for decomp_size in decomposition_sizes:
        results[decomp_size] = {}
        for overlap_ratio in overlap_ratios:
            results[decomp_size][overlap_ratio] = []
    
    # Run experiments with detailed tracking
    for decomp_size in decomposition_sizes:
        print(f"\n{'='*60}")
        print(f"Testing D={decomp_size}")
        print(f"{'='*60}")
        
        for overlap_ratio in overlap_ratios:
            overlap = max(0, int(decomp_size * overlap_ratio))
            
            print(f"\n  Overlap: {overlap_ratio*100:.0f}% ({overlap} abs)")
            
            # Calculate theoretical redundancy
            stride = max(1, decomp_size - overlap)
            n_subproblems = max(1, (actual_size - decomp_size) // stride + 1)
            
            # Theoretical metrics
            total_indices_processed = n_subproblems * decomp_size
            unique_indices = actual_size
            theoretical_redundancy = total_indices_processed / unique_indices
            
            print(f"    Subproblems: {n_subproblems}, Stride: {stride}")
            print(f"    Theoretical redundancy: {theoretical_redundancy:.2f}x")
            
            # Base parameters
            # base_params = {
            #     'learning_rate': 0.5,
            #     'ucb_coefficient': 3.0,
            #     'initial_temperature': 1.0,
            #     'temp_decay': 0.98,
            #     'hybrid_ratio': 0.7,
            #     'adaptive_hybrid': True,
            #     'decomposition_size': decomp_size,
            #     'overlap': overlap,
            #     'max_iterations': 100,
            #     'nb_rounds': 5,
            #     'patience': 20,
            #     'use_lagrangian': True,
            #     'use_ftrl': True,
            #     'dual_step_size': 1.0,
            #     'use_accelerated_dual': True,
            #     'use_diminishing_overlap': False,
            #     'n_weight_vectors': 15,
            #     'use_correlation_decomposition': True,
            #     'use_elite_decomposition': True,
            #     'use_metric_decomposition': True,
            # }

            base_params = {
                    'learning_rate': 0.1,
                    'initial_temperature': 1.0,
                    'temp_decay': 0.98,
                    'hybrid_ratio': 0.5,
                    'adaptive_hybrid': True,
                    'decomposition_size': decomp_size,
                    'overlap': overlap,
                    'max_iterations': 80,
                    'nb_rounds': 15,
                    'patience': 30,
                    'use_lagrangian': True,
                    'use_ftrl': True,
                    'dual_step_size': 1.0,
                    'use_accelerated_dual': True,
                    'use_diminishing_overlap': False,
                    'n_weight_vectors': 30,
                }
            
            # Run multiple times
            for run in range(num_runs):
                try:
                    evaluator = MOCOEvaluator(reference_point=ref_point)
                
                    start_time = time.time()
                    result = evaluator.evaluate_algorithm(
                        algorithm_class=algorithm_class,
                        problem_class=problem_class,
                        algorithm_name=f"D{decomp_size}_O{int(overlap_ratio*100)}",
                        parameters=base_params,
                        problem_params=problem_params,
                        num_runs=1
                    )
                    wall_time = time.time() - start_time
                    
                    # Calculate tour length
                    tour_length = None
                    if problem_type == 'BiTSP' and result.objectives:
                        tour_length = np.mean([obj[0] + obj[1] for obj in result.objectives])
                    elif problem_type == 'BiKP' and result.objectives:
                        tour_length = -np.mean([obj[0] + obj[1] for obj in result.objectives])
                    
                    # ✅ SIMPLE FIX: Get convergence stats from result.convergence_stats
                    # The wrapper stores this during run() via get_aggregated_stats()
                    convergence_iter = 80
                    total_work = 80 * n_subproblems
                    main_loop_iters = 80
                    convergence_rate = 0.0
                    n_converged = 0  # ✅ Initialize BEFORE the if block
                    n_total = 15     # ✅ Initialize BEFORE the if block

                    if hasattr(result, 'convergence_stats'):
                        stats = result.convergence_stats
                        
                        main_loop_iters = stats.get('main_loop_iters_mean', 80)
                        conv_iter_mean = stats.get('convergence_iters_mean', 80)
                        total_work = stats.get('total_subprob_evals_mean', 80 * n_subproblems)
                        
                        n_converged = stats.get('n_weights_converged', 0)
                        n_total = stats.get('n_weights_total', 15)
                        convergence_rate = stats.get('convergence_rate', 0.0)
                        
                        # Use the mean convergence iteration
                        convergence_iter = conv_iter_mean

                    empirical_redundancy = total_work / actual_size

                    # Store with more detail
                    results[decomp_size][overlap_ratio].append({
                        'hypervolume': result.hypervolume,
                        'runtime': result.runtime,
                        'wall_time': wall_time,
                        'num_solutions': result.num_nondominated,
                        'tour_length': tour_length,
                        'n_subproblems': n_subproblems,
                        'convergence_iter': convergence_iter,  # Average of converged weights
                        'main_loop_iters': main_loop_iters,
                        'total_work': total_work,
                        'theoretical_redundancy': theoretical_redundancy,
                        'empirical_redundancy': empirical_redundancy,
                        'stagnation_count': 0,
                        'stride': stride,
                        'convergence_rate': convergence_rate,  # NEW: % that converged
                        'n_converged': n_converged,
                        'n_total': n_total,
                    })

                    # Better printing with convergence rate
                    print(f"      Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, "
                        f"Time: {wall_time:.2f}s, Conv@: {convergence_iter:.0f} "
                        f"({n_converged}/{n_total} = {convergence_rate*100:.0f}%), "
                        f"RedundEmp: {empirical_redundancy:.2f}x")
                    
                except Exception as e:
                    print(f"      Run {run+1} failed: {e}")
                    import traceback
                    traceback.print_exc()
    
    # Analyze and compute harm thresholds
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)
    
    harm_analysis = analyze_harm_thresholds(
        results, decomposition_sizes, overlap_ratios, 
        actual_size, output_dir, timestamp
    )
    
    # Generate comprehensive plots
    plot_harm_detection(
        results, decomposition_sizes, overlap_ratios,
        harm_analysis, problem_type, problem_size, 
        output_dir, timestamp, actual_size
    )
    
    # Save detailed results
    save_harm_detection_to_csv(
        results, harm_analysis, decomposition_sizes, overlap_ratios,
        problem_type, problem_size, output_dir, timestamp
    )
    
    return results, harm_analysis


def analyze_harm_thresholds(
    results, decomposition_sizes, overlap_ratios,
    actual_size, output_dir, timestamp
):
    """
    Analyze results to identify harm thresholds for each decomposition size.

    For every decomposition size, the per-run records are averaged per overlap
    ratio and four criteria are checked. The baseline is the analysed overlap
    ratio closest to 20%.

    1. Runtime multiplier > 2x baseline
    2. Redundancy ratio > 2.5x
    3. Marginal efficiency (delta HV / delta runtime) below 50% of the mean
       of the first three measured efficiencies
    4. Convergence iterations > 1.5x baseline

    The overall harm threshold is the smallest overlap percentage flagged by
    any criterion, or 100.0 when no criterion fires. A report is printed to
    stdout as the analysis proceeds.

    Args:
        results: Mapping ``decomp_size -> overlap_ratio -> list of run dicts``.
            Every run dict must provide 'hypervolume', 'wall_time',
            'empirical_redundancy', 'convergence_iter' and 'stagnation_count'.
        decomposition_sizes: Decomposition sizes to analyse.
        overlap_ratios: Overlap ratios (fractions) in the order they should
            be scanned; the first ratio that violates a criterion wins.
        actual_size: Not used by this function; kept for call compatibility.
        output_dir: Not used by this function; kept for call compatibility.
        timestamp: Not used by this function; kept for call compatibility.

    Returns:
        Dict keyed by decomposition size. Sizes with fewer than three
        overlap ratios that have at least one run are omitted. Each value
        holds 'harm_threshold', the four per-criterion thresholds (None when
        a criterion never fired or could not be evaluated),
        'baseline_runtime', 'baseline_hv', 'overlap_data' and
        'marginal_efficiencies'.
    """

    def fmt_threshold(t):
        # Thresholds are overlap percentages; None means "never triggered".
        return f"{t:.0f}%" if t is not None else "N/A"

    def mean_std(runs, key):
        values = [r[key] for r in runs]
        return np.mean(values), np.std(values)

    print("\n" + "="*80)
    print("HARM THRESHOLD ANALYSIS")
    print("="*80)

    harm_analysis = {}

    for decomp_size in decomposition_sizes:
        print(f"\n{'='*60}")
        print(f"Analyzing D={decomp_size}")
        print(f"{'='*60}")

        # .get() tolerates sizes with no results at all (plain dict) and
        # avoids inserting empty entries into a defaultdict as a side effect.
        runs_by_overlap = results.get(decomp_size) or {}

        # Aggregate statistics for this decomposition size
        overlap_data = []
        for overlap_ratio in overlap_ratios:
            runs = runs_by_overlap.get(overlap_ratio)
            if not runs:
                # No successful run: averaging would only produce NaNs.
                continue

            hv_mean, hv_std = mean_std(runs, 'hypervolume')
            runtime_mean, runtime_std = mean_std(runs, 'wall_time')
            redundancy_mean, redundancy_std = mean_std(runs, 'empirical_redundancy')
            convergence_mean, convergence_std = mean_std(runs, 'convergence_iter')

            overlap_data.append({
                'overlap_ratio': overlap_ratio,
                'overlap_pct': overlap_ratio * 100,
                'hv_mean': hv_mean,
                'hv_std': hv_std,
                'runtime_mean': runtime_mean,
                'runtime_std': runtime_std,
                'redundancy_mean': redundancy_mean,
                'redundancy_std': redundancy_std,
                'convergence_mean': convergence_mean,
                'convergence_std': convergence_std,
                'stagnation_mean': np.mean([r['stagnation_count'] for r in runs]),
            })

        if len(overlap_data) < 3:
            print("  ⚠️ Insufficient data for analysis")
            continue

        # Baseline: use 20% overlap as reference (or closest available)
        baseline = min(overlap_data, key=lambda d: abs(d['overlap_ratio'] - 0.2))
        baseline_runtime = baseline['runtime_mean']
        baseline_hv = baseline['hv_mean']
        baseline_conv = baseline['convergence_mean']

        print(f"\n  Baseline (≈20% overlap): Runtime={baseline_runtime:.2f}s, HV={baseline_hv:.4f}")

        # Criterion 1: Runtime multiplier > 2x
        runtime_threshold = None
        if baseline_runtime > 0:
            for data in overlap_data:
                runtime_multiplier = data['runtime_mean'] / baseline_runtime
                if runtime_multiplier > 2.0:
                    runtime_threshold = data['overlap_pct']
                    print(f"  ⚠️ Runtime threshold: {runtime_threshold:.0f}% (multiplier: {runtime_multiplier:.2f}x)")
                    break

            if runtime_threshold is None:
                worst = max(d['runtime_mean'] / baseline_runtime for d in overlap_data)
                print(f"  ✓ Runtime stays reasonable throughout (max: {worst:.2f}x)")
        else:
            # A non-positive baseline would turn every multiplier into
            # inf/nan and flag a meaningless threshold.
            print("  ⚠️ Baseline runtime is not positive; runtime criterion skipped")

        # Criterion 2: Redundancy ratio > 2.5x
        redundancy_threshold = None
        for data in overlap_data:
            if data['redundancy_mean'] > 2.5:
                redundancy_threshold = data['overlap_pct']
                print(f"  ⚠️ Redundancy threshold: {redundancy_threshold:.0f}% (redundancy: {data['redundancy_mean']:.2f}x)")
                break

        if redundancy_threshold is None:
            print(f"  ✓ Redundancy stays reasonable (max: {max(d['redundancy_mean'] for d in overlap_data):.2f}x)")

        # Criterion 3: Marginal efficiency (ΔHV / ΔRuntime)
        marginal_efficiencies = []
        efficiency_threshold = None

        for previous, current in zip(overlap_data, overlap_data[1:]):
            delta_hv = current['hv_mean'] - previous['hv_mean']
            delta_runtime = current['runtime_mean'] - previous['runtime_mean']

            # Only steps that cost extra runtime have a defined efficiency.
            if delta_runtime > 0:
                marginal_efficiencies.append({
                    'overlap_pct': current['overlap_pct'],
                    'efficiency': delta_hv / delta_runtime,
                    'delta_hv': delta_hv,
                    'delta_runtime': delta_runtime,
                })

        if len(marginal_efficiencies) > 2:
            # Find where efficiency drops significantly
            mean_efficiency = np.mean([e['efficiency'] for e in marginal_efficiencies[:3]])

            for eff_data in marginal_efficiencies[2:]:
                if eff_data['efficiency'] < 0.5 * mean_efficiency:  # 50% drop
                    efficiency_threshold = eff_data['overlap_pct']
                    print(f"  ⚠️ Efficiency threshold: {efficiency_threshold:.0f}% "
                          f"(efficiency dropped to {eff_data['efficiency']:.6f} HV/s)")
                    break

            if efficiency_threshold is None:
                print(f"  ✓ Marginal efficiency stays reasonable")

        # Criterion 4: Convergence penalty (iterations increase)
        convergence_threshold = None
        if baseline_conv > 0:
            for data in overlap_data:
                conv_multiplier = data['convergence_mean'] / baseline_conv
                if conv_multiplier > 1.5:
                    convergence_threshold = data['overlap_pct']
                    print(f"  ⚠️ Convergence threshold: {convergence_threshold:.0f}% "
                          f"(iterations: {conv_multiplier:.2f}x baseline)")
                    break

            if convergence_threshold is None:
                print(f"  ✓ Convergence stays efficient")
        else:
            # Same reasoning as for runtime: x / 0 is inf or nan.
            print("  ⚠️ Baseline convergence is not positive; convergence criterion skipped")

        # Determine overall harm threshold (most conservative)
        thresholds = [t for t in [runtime_threshold, redundancy_threshold,
                                   efficiency_threshold, convergence_threshold]
                     if t is not None]

        if thresholds:
            harm_threshold = min(thresholds)
            print(f"\n  🎯 OVERALL HARM THRESHOLD: {harm_threshold:.0f}%")
            print(f"     (Most conservative of detected thresholds)")
        else:
            harm_threshold = 100.0
            print(f"\n  ✅ NO HARM DETECTED up to 100% overlap")

        # Store analysis
        harm_analysis[decomp_size] = {
            'harm_threshold': harm_threshold,
            'runtime_threshold': runtime_threshold,
            'redundancy_threshold': redundancy_threshold,
            'efficiency_threshold': efficiency_threshold,
            'convergence_threshold': convergence_threshold,
            'baseline_runtime': baseline_runtime,
            'baseline_hv': baseline_hv,
            'overlap_data': overlap_data,
            'marginal_efficiencies': marginal_efficiencies,
        }

    # Summary table
    print("\n" + "="*80)
    print("HARM THRESHOLD SUMMARY")
    print("="*80)
    print(f"{'D':>4} | {'Overall':>8} | {'Runtime':>8} | {'Redundancy':>10} | "
          f"{'Efficiency':>10} | {'Convergence':>11}")
    print("-"*80)

    for decomp_size in sorted(harm_analysis.keys()):
        analysis = harm_analysis[decomp_size]

        print(f"{decomp_size:>4} | {fmt_threshold(analysis['harm_threshold']):>8} | "
              f"{fmt_threshold(analysis['runtime_threshold']):>8} | "
              f"{fmt_threshold(analysis['redundancy_threshold']):>10} | "
              f"{fmt_threshold(analysis['efficiency_threshold']):>10} | "
              f"{fmt_threshold(analysis['convergence_threshold']):>11}")

    print("="*80)

    return harm_analysis


def plot_harm_detection(
    results, decomposition_sizes, overlap_ratios, harm_analysis,
    problem_type, problem_size, output_dir, timestamp, actual_size
):
    """
    Render the 2x2 harm-detection figure and save it as PNG and PDF.

    One curve is drawn per decomposition size present in ``harm_analysis``
    (sizes without an analysis entry are skipped):

    1. Redundancy ratio vs overlap
    2. Marginal efficiency vs overlap
    3. Runtime multiplier vs overlap
    4. Convergence multiplier vs overlap

    Each panel also gets a dashed vertical line at that criterion's threshold
    for every size whose threshold is not None.

    Args:
        results: Not read by this function.
        decomposition_sizes: Sizes to plot; the position in this sequence
            selects the curve colour.
        overlap_ratios: Not read by this function.
        harm_analysis: Per-size analysis dicts providing 'overlap_data',
            'marginal_efficiencies', 'baseline_runtime' and the four
            '*_threshold' entries.
        problem_type: Used in the figure title and the file name.
        problem_size: Used in the file name.
        output_dir: Directory the PNG/PDF are written to.
        timestamp: Suffix of the file name.
        actual_size: Appended to ``problem_type`` in the figure title.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("GENERATING HARM DETECTION PLOTS")
    print(rule)

    # 2x2 grid; constrained layout is disabled because spacing is set by hand
    # via subplots_adjust further down.
    fig, axes = plt.subplots(2, 2, figsize=(14, 10), constrained_layout=False)
    redundancy_ax, efficiency_ax, runtime_ax, convergence_ax = axes.flatten()

    # One colour per requested decomposition size.
    palette_size = len(decomposition_sizes)
    cmap = plt.cm.tab10 if palette_size <= 10 else plt.cm.viridis
    colors = cmap(np.linspace(0, 1, palette_size))

    # (colour, size, analysis) for every size that actually has an analysis.
    analysed = [
        (colors[idx], size, harm_analysis[size])
        for idx, size in enumerate(decomposition_sizes)
        if size in harm_analysis
    ]

    def draw_series(ax, xs, ys, color, size, spread=None):
        # Line with markers, plus an optional +/- spread band around it.
        ax.plot(xs, ys, 'o-',
                color=color, label=f'D={size}',
                linewidth=2, markersize=6, alpha=0.85)
        if spread is not None:
            centre = np.array(ys)
            half_width = np.array(spread)
            ax.fill_between(xs, centre - half_width, centre + half_width,
                            alpha=0.12, color=color)

    def mark_threshold(ax, threshold, color):
        # Dashed vertical marker at the overlap where the criterion fired.
        if threshold is not None:
            ax.axvline(threshold, color=color, linestyle='--', alpha=0.4, linewidth=1.5)

    def finish_panel(ax, ylabel, title):
        ax.set_xlabel('Overlap (% of Decomposition Size)', fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.set_title(title, fontsize=11, fontweight='bold', pad=8)
        ax.legend(fontsize=8, framealpha=0.9, edgecolor='0.8', ncol=2)
        ax.grid(True, alpha=0.3, linestyle='--')

    # Panel 1: redundancy ratio
    for color, size, analysis in analysed:
        rows = analysis['overlap_data']
        draw_series(
            redundancy_ax,
            [row['overlap_pct'] for row in rows],
            [row['redundancy_mean'] for row in rows],
            color, size,
            spread=[row['redundancy_std'] for row in rows],
        )
        mark_threshold(redundancy_ax, analysis['redundancy_threshold'], color)

    redundancy_ax.axhline(2.5, color='red', linestyle='--', alpha=0.5, linewidth=2, label='Harm Threshold (2.5×)')
    finish_panel(redundancy_ax,
                 'Redundancy Ratio (Total Work / Problem Size)',
                 'Computational Redundancy (↓ better)')

    # Panel 2: marginal efficiency (no spread band, no horizontal reference)
    for color, size, analysis in analysed:
        points = analysis['marginal_efficiencies']
        if not points:
            continue
        draw_series(
            efficiency_ax,
            [pt['overlap_pct'] for pt in points],
            [pt['efficiency'] for pt in points],
            color, size,
        )
        mark_threshold(efficiency_ax, analysis['efficiency_threshold'], color)

    finish_panel(efficiency_ax,
                 'Marginal Efficiency (ΔHV / ΔRuntime)',
                 'Marginal Benefit Analysis (↑ better)')
    efficiency_ax.set_yscale('symlog', linthresh=0.001)  # Handle potential negative values

    # Panel 3: runtime relative to the stored baseline runtime
    for color, size, analysis in analysed:
        rows = analysis['overlap_data']
        reference_runtime = analysis['baseline_runtime']
        draw_series(
            runtime_ax,
            [row['overlap_pct'] for row in rows],
            [row['runtime_mean'] / reference_runtime for row in rows],
            color, size,
            spread=[row['runtime_std'] / reference_runtime for row in rows],
        )
        mark_threshold(runtime_ax, analysis['runtime_threshold'], color)

    runtime_ax.axhline(2.0, color='red', linestyle='--', alpha=0.5, linewidth=2, label='Harm Threshold (2×)')
    finish_panel(runtime_ax,
                 'Runtime Multiplier (vs 20% baseline)',
                 'Runtime Penalty (↓ better)')

    # Panel 4: convergence iterations relative to the row closest to 20% overlap
    for color, size, analysis in analysed:
        rows = analysis['overlap_data']
        reference_row = min(rows, key=lambda row: abs(row['overlap_ratio'] - 0.2))
        reference_conv = reference_row['convergence_mean']
        draw_series(
            convergence_ax,
            [row['overlap_pct'] for row in rows],
            [row['convergence_mean'] / reference_conv for row in rows],
            color, size,
            spread=[row['convergence_std'] / reference_conv for row in rows],
        )
        mark_threshold(convergence_ax, analysis['convergence_threshold'], color)

    convergence_ax.axhline(1.5, color='red', linestyle='--', alpha=0.5, linewidth=2, label='Harm Threshold (1.5×)')
    finish_panel(convergence_ax,
                 'Convergence Multiplier (vs 20% baseline)',
                 'Convergence Slowdown (↓ better)')

    # Overall title
    fig.suptitle(
        f'Overlap Harm Detection Analysis ({problem_type}{actual_size})\n'
        f'Dashed vertical lines = threshold where harm begins',
        fontsize=13, fontweight='bold'
    )

    # Manual spacing
    plt.subplots_adjust(top=0.92, bottom=0.07, left=0.08, right=0.97,
                       hspace=0.28, wspace=0.25)

    # Write the PNG first, then the PDF under the same base name.
    plot_file = os.path.join(
        output_dir,
        f'harm_detection_{problem_type}_{problem_size}_{timestamp}.png'
    )
    plt.savefig(plot_file, dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.savefig(plot_file.replace('.png', '.pdf'), bbox_inches='tight', pad_inches=0.1)
    print(f"Saved: {plot_file}")
    plt.close()


def save_harm_detection_to_csv(
    results, harm_analysis, decomposition_sizes, overlap_ratios,
    problem_type, problem_size, output_dir, timestamp
):
    """
    Save harm detection results to CSV.

    Up to three files are written into ``output_dir`` (created if missing),
    all named ``<prefix>_<problem_type>_<problem_size>_<timestamp>.csv``:

    1. ``harm_detection_detailed``: one row per individual run.
    2. ``harm_thresholds_summary``: one row per analysed decomposition size.
    3. ``marginal_efficiency``: one row per marginal-efficiency point; only
       written when at least one such point exists.

    Args:
        results: Mapping ``decomp_size -> overlap_ratio -> list of run dicts``.
            Sizes or ratios without an entry are skipped.
        harm_analysis: Mapping ``decomp_size -> analysis dict`` providing
            'harm_threshold', 'baseline_runtime' and 'baseline_hv', and
            optionally the per-criterion thresholds and
            'marginal_efficiencies'.
        decomposition_sizes: Sizes to export, in row order.
        overlap_ratios: Overlap ratios to export, in row order.
        problem_type: Written to every row and used in the file names.
        problem_size: Written to every row and used in the file names.
        output_dir: Destination directory.
        timestamp: Written to every row and used in the file names.

    Returns:
        Tuple ``(csv_detailed, csv_thresholds)`` with the paths of the first
        two files.
    """
    # Make the function usable on its own, not only after the caller has
    # already created the directory.
    os.makedirs(output_dir, exist_ok=True)

    def csv_path(prefix):
        return os.path.join(
            output_dir,
            f'{prefix}_{problem_type}_{problem_size}_{timestamp}.csv'
        )

    # 1. Detailed metrics CSV
    detailed_data = []
    for decomp_size in decomposition_sizes:
        # .get() skips sizes that produced no results instead of raising.
        runs_by_overlap = results.get(decomp_size) or {}
        for overlap_ratio in overlap_ratios:
            runs = runs_by_overlap.get(overlap_ratio, [])
            for run_number, run in enumerate(runs, start=1):
                detailed_data.append({
                    'problem_type': problem_type,
                    'problem_size': problem_size,
                    'decomposition_size': decomp_size,
                    'overlap_ratio': overlap_ratio,
                    'overlap_percentage': overlap_ratio * 100,
                    'run_number': run_number,
                    'hypervolume': run['hypervolume'],
                    'runtime': run['runtime'],
                    'wall_time': run['wall_time'],
                    'num_solutions': run['num_solutions'],
                    'tour_length': run.get('tour_length', np.nan),
                    'n_subproblems': run['n_subproblems'],
                    'convergence_iter': run['convergence_iter'],
                    'total_work': run['total_work'],
                    'theoretical_redundancy': run['theoretical_redundancy'],
                    'empirical_redundancy': run['empirical_redundancy'],
                    'stagnation_count': run['stagnation_count'],
                    'stride': run['stride'],
                    'timestamp': timestamp,
                    # Convergence bookkeeping recorded per run. Appended after
                    # the existing columns so their positions stay unchanged;
                    # .get() keeps run records without these keys exportable.
                    'main_loop_iters': run.get('main_loop_iters', np.nan),
                    'convergence_rate': run.get('convergence_rate', np.nan),
                    'n_converged': run.get('n_converged', np.nan),
                    'n_total': run.get('n_total', np.nan),
                })

    csv_detailed = csv_path('harm_detection_detailed')
    pd.DataFrame(detailed_data).to_csv(csv_detailed, index=False)
    print(f"Detailed metrics saved to: {csv_detailed}")

    # 2. Harm thresholds summary CSV
    threshold_data = []
    for decomp_size in sorted(harm_analysis.keys()):
        analysis = harm_analysis[decomp_size]
        threshold_data.append({
            'problem_type': problem_type,
            'problem_size': problem_size,
            'decomposition_size': decomp_size,
            'overall_harm_threshold': analysis['harm_threshold'],
            'runtime_threshold': analysis.get('runtime_threshold'),
            'redundancy_threshold': analysis.get('redundancy_threshold'),
            'efficiency_threshold': analysis.get('efficiency_threshold'),
            'convergence_threshold': analysis.get('convergence_threshold'),
            'baseline_runtime': analysis['baseline_runtime'],
            'baseline_hv': analysis['baseline_hv'],
            'timestamp': timestamp,
        })

    csv_thresholds = csv_path('harm_thresholds_summary')
    pd.DataFrame(threshold_data).to_csv(csv_thresholds, index=False)
    print(f"Harm thresholds saved to: {csv_thresholds}")

    # 3. Marginal efficiency CSV
    efficiency_data = [
        {
            'problem_type': problem_type,
            'problem_size': problem_size,
            'decomposition_size': decomp_size,
            'overlap_percentage': eff_point['overlap_pct'],
            'marginal_efficiency': eff_point['efficiency'],
            'delta_hypervolume': eff_point['delta_hv'],
            'delta_runtime': eff_point['delta_runtime'],
            'timestamp': timestamp,
        }
        for decomp_size in sorted(harm_analysis.keys())
        for eff_point in harm_analysis[decomp_size].get('marginal_efficiencies', [])
    ]

    # Skipped entirely when there is nothing to report.
    if efficiency_data:
        csv_efficiency = csv_path('marginal_efficiency')
        pd.DataFrame(efficiency_data).to_csv(csv_efficiency, index=False)
        print(f"Marginal efficiency saved to: {csv_efficiency}")

    return csv_detailed, csv_thresholds



# ============================================================================
# UPDATED EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Script entry point. Examples 1-3 are kept as disabled reference snippets;
    # only the overlap harm-detection experiment at the bottom actually runs.
    banner_rule = "=" * 80

    # -------------------------------------------------------------------------
    # Example 1 (disabled): overlap ablation
    # -------------------------------------------------------------------------
    # print("\n" + "="*80)
    # print("EXAMPLE 1: OVERLAP ABLATION")
    # print("="*80)
    #
    # algorithms = {
    #     'UCB-Exp3': UCBWrapper,
    #     # 'Thompson-Exp3': ThompsonWrapper,
    # }
    #
    # # Optional: baseline solver results drawn as reference lines
    # baseline_solvers = {
    #     'WS': {'hypervolume': 0.69, 'runtime': 72},
    #     # 'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5}
    # }
    #
    # results_overlap = ablation_overlap_vs_decomposition(
    #     algorithm_classes=algorithms,
    #     problem_type='BiTSP',
    #     problem_size='large',
    #     # small: [5, 10, 15, 20]; large: [10, 15, 25, 40, 50, 70, 90]
    #     decomposition_sizes=[10, 15, 25, 35, 40, 50, 70, 90, 100],
    #     overlap_ratios=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    #     num_runs=50,
    #     output_dir='ablation_results_complete_overlap_UCB',
    #     plot_metrics=['hypervolume', 'runtime', 'tour_length', 'solutions'],
    #     baseline_solvers=baseline_solvers  # Add baseline reference lines
    # )

    # -------------------------------------------------------------------------
    # Example 2 (disabled): post-run replotting
    # Uncomment to regenerate plots from saved YAML results with new settings.
    # -------------------------------------------------------------------------
    # print("\n" + "="*80)
    # print("EXAMPLE 2: REPLOTTING FROM SAVED RESULTS")
    # print("="*80)
    #
    # # New baseline solvers (may differ from the original run)
    # baseline_solvers_new = {
    #     'WS': {'hypervolume': 0.52, 'runtime': 9.5},
    #     # 'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5},
    #     # 'MOEA/D': {'hypervolume': 0.49, 'runtime': 11.0}
    # }
    #
    # # Load and replot with new settings
    # reloaded_results = load_and_replot(
    #     yaml_file='ablation_results_complete_overlap_thompson/ablation_overlap_decomp_BiTSP_20251126-153709.yaml',
    #     plot_metrics=['hypervolume', 'runtime', 'tour_length', 'solutions'],  # Can change metrics
    #     baseline_solvers=baseline_solvers,     # Add/modify baselines
    #     output_suffix='re_gen'                 # Custom suffix for new plots
    # )

    # -------------------------------------------------------------------------
    # Example 3 (disabled): decomposition diagnostic with convergence tracking
    # -------------------------------------------------------------------------
    # print("\n" + "="*80)
    # print("EXAMPLE 3: DECOMPOSITION DIAGNOSTIC WITH CONVERGENCE TRACKING")
    # print("="*80)
    #
    # results_decomp = decomposition_diagnostic_with_convergence(
    #     algorithm_class=UCBWrapper,
    #     problem_type='BiTSP',
    #     problem_size='medium',
    #     decomp_sizes=[5, 10, 15, 20, 30, 45],
    #     num_runs=1,
    #     output_dir='diagnostic_results',
    #     plot_options={
    #         'n_subproblems': True,
    #         'main_iterations': True,
    #         'convergence_iterations': True,  # U-curve expected
    #         'total_work': True,              # U-curve expected
    #         'wall_time': True,
    #         'hypervolume': True,
    #         'tour_length': True,
    #         'efficiency': False
    #     }
    # )

    # NOTE(review): this summary is printed before anything has run, and it
    # points at the output folders of the disabled examples above; confirm
    # whether it should still be emitted.
    for summary_line in (
        "DONE! Check these folders:",
        "  - ablation_results/ (overlap ablation)",
        "  - diagnostic_results/ (decomposition with U-curves!)",
        banner_rule,
    ):
        print(summary_line)

    # Harm detection experiment (the only active experiment)
    print("\n" + banner_rule)
    print("EXPERIMENT: OVERLAP HARM DETECTION")
    print(banner_rule)

    harm_experiment_config = dict(
        algorithm_class=UCBWrapper,
        problem_type='BiTSP',
        problem_size='large',
        # small: [5, 10, 15, 20]; large: [10, 20, 30, 50]
        decomposition_sizes=[10, 20, 30, 40, 50, 80, 100],
        overlap_ratios=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        num_runs=50,
        output_dir='harm_detection_results',
    )
    results_harm, harm_analysis = experiment_overlap_harm_detection(**harm_experiment_config)

    print("\n" + banner_rule)
    for closing_line in (
        "HARM DETECTION COMPLETE!",
        "Check harm_detection_results/ for:",
        "  - Comprehensive plots (4 subplots)",
        "  - Detailed metrics CSV",
        "  - Harm thresholds summary CSV",
        "  - Marginal efficiency CSV",
        banner_rule,
    ):
        print(closing_line)
