import os
import yaml
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict
import torch
import warnings
import pickle
import json
import glob

warnings.filterwarnings('ignore')

import sys
from pathlib import Path

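# Add the parent of this file's directory (assumed to be the project root) to sys.path so project imports resolve regardless of the current working directory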
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Import MOCO components
from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP
from MOCO.evaluation import MOCOEvaluator

# Import algorithm wrappers
from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
    CachedAdvancedBiKPWrapper as UCBWrapper
)
from our_method_dl_Thompson_variant import (
    CachedAdvancedBiKPWrapper as ThompsonWrapper
)

# ============================================================================
# ICML STYLE CONFIGURATION
# ============================================================================

def setup_icml_style():
    """Apply the shared paper-style matplotlib/seaborn settings.

    Selects the ``seaborn-v0_8-paper`` style sheet, makes ``husl`` the active
    seaborn palette and then overrides a fixed set of rcParams covering fonts,
    figure/savefig output, layout, grid/spines, lines and legend.
    The function mutates global matplotlib state and returns nothing.
    """
    plt.style.use('seaborn-v0_8-paper')

    # Active color palette for all seaborn plots.
    sns.set_palette("husl")
    # Alternative explicit palette (currently disabled):
    # colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161', '#949494']
    # sns.set_palette(sns.color_palette(colors))

    # Sans-serif font stack and the font sizes of the individual text elements
    font_settings = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'Liberation Sans'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
    }

    # On-screen resolution vs. resolution/padding of saved files
    figure_settings = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
    }

    # Constrained layout on, automatic tight layout off
    layout_settings = {
        'figure.constrained_layout.use': True,
        'figure.autolayout': False,
    }

    # Dashed light grid; top and right spines hidden
    axes_settings = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
    }

    line_settings = {
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
    }

    legend_settings = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
        'legend.edgecolor': '0.8',
    }

    # Draw grid lines and ticks below the plotted artists
    misc_settings = {
        'axes.axisbelow': True,
    }

    # Merge the groups in their original order and apply them in one call
    plt.rcParams.update({
        **font_settings,
        **figure_settings,
        **layout_settings,
        **axes_settings,
        **line_settings,
        **legend_settings,
        **misc_settings,
    })

# Apply the style once at import time; figures created afterwards use these rcParams.
setup_icml_style()

# ============================================================================
# PROBLEM CONFIGURATIONS
# ============================================================================

# Registry of benchmark problems, keyed by problem type name.
# Each entry provides:
#   'class'       - problem class handed to MOCOEvaluator.evaluate_algorithm as problem_class
#   'ref_type'    - type tag (not read by the ablation functions visible in this file section)
#   'sizes'       - size label ('small'/'medium'/'large') -> kwargs passed on as problem_params
#   'default_ref' - callable returning the reference point handed to MOCOEvaluator.
#                   NOTE: its argument is the numeric instance size (n_cities / n_items),
#                   not the size label; any value other than the two tested ones falls
#                   through to the last ("large") reference point.
PROBLEM_REGISTRY = {
    'BiTSP': {
        'class': BiObjectiveTSP,
        'ref_type': 'BiTSP',
        'sizes': {
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
        },
        # 20 cities -> (20, 20); 50 cities -> (35, 35); otherwise (65, 65)
        'default_ref': lambda size: (20, 20) if size == 20 else (35, 35) if size == 50 else (65, 65)
    },
    'BiKP': {
        'class': MultiObjectiveKnapsack,
        'ref_type': 'BiKP',
        'sizes': {
            'small': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5},
            'medium': {'n_items': 100, 'n_objectives': 2, 'capacity': 25.0},
            'large': {'n_items': 200, 'n_objectives': 2, 'capacity': 50.0},
        },
        # 50 items -> (5, 5); 100 items -> (20, 20); otherwise (30, 30)
        'default_ref': lambda size: (5, 5) if size == 50 else (20, 20) if size == 100 else (30, 30)
    },
    'TriTSP': {
        'class': TriObjectiveTSP,
        'ref_type': 'TriTSP',
        'sizes': {
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
        },
        # Same thresholds as BiTSP, with a three-component reference point
        'default_ref': lambda size: (20, 20, 20) if size == 20 else (35, 35, 35) if size == 50 else (65, 65, 65)
    }
}


# ============================================================================
# BEST HYPERPARAMETER CONFIGURATIONS (consolidated from ablation studies)
# ============================================================================

# Tuned hyperparameters, keyed by problem type and then by size label.
# get_base_params() layers the matching entry on top of BASE_ALGORITHM_PARAMS
# when both a problem type and a problem size are supplied.
# Every entry defines the same five keys: decomposition_size, overlap,
# n_weight_vectors, nb_rounds and patience.
BEST_CONFIG = {
    'BiTSP': {
        'small': {
            'decomposition_size': 15,
            'overlap': 6,
            'n_weight_vectors': 15,
            'nb_rounds': 5,
            'patience': 20
        },
        'medium': {
            'decomposition_size': 25,
            'overlap': 10,
            'n_weight_vectors': 25,
            'nb_rounds': 10,
            'patience': 50
        },
        'large': {
            'decomposition_size': 35,
            'overlap': 15,
            'n_weight_vectors': 30,
            'nb_rounds': 30,
            'patience': 100
        }
    },
    'BiKP': {
        'small': {
            'decomposition_size': 15,
            'overlap': 6,
            'n_weight_vectors': 20,
            'nb_rounds': 5,
            'patience': 30
        },
        'medium': {
            'decomposition_size': 30,
            'overlap': 12,
            'n_weight_vectors': 20,
            'nb_rounds': 5,
            'patience': 20
        },
        'large': {
            'decomposition_size': 50,
            'overlap': 20,
            'n_weight_vectors': 30,
            'nb_rounds': 5,
            'patience': 30
        }
    },
    'TriTSP': {
        'small': {
            'decomposition_size': 15,
            'overlap': 6,
            'n_weight_vectors': 15,
            'nb_rounds': 5,
            'patience': 20
        },
        'medium': {
            'decomposition_size': 25,
            'overlap': 10,
            'n_weight_vectors': 25,
            'nb_rounds': 5,
            'patience': 20
        },
        'large': {
            'decomposition_size': 35,
            'overlap': 15,
            'n_weight_vectors': 35,
            'nb_rounds': 5,
            'patience': 30
        }
    }
}

# ============================================================================
# GLOBAL BASE PARAMETERS (shared across all ablation studies)
# ============================================================================

# Default algorithm parameters. get_base_params() starts from a copy of this
# dict, so the ablation functions never mutate it directly.
# ablation_ftrl_toggle overrides 'learning_rate' and 'max_iterations' and
# toggles 'use_ftrl' per configuration.
# NOTE(review): the meaning of each flag is defined by the algorithm wrappers,
# which are not part of this file - confirm there before changing values.
BASE_ALGORITHM_PARAMS = {
    'learning_rate': 0.5,
    'initial_temperature': 1.0,
    'temp_decay': 0.98,
    'hybrid_ratio': 0.7,
    'adaptive_hybrid': True,
    'max_iterations': 80,
    'use_lagrangian': True,
    'use_ftrl': True,
    'dual_step_size': 1.0,
    'use_accelerated_dual': True,
    'use_diminishing_overlap': True,
    'overlap_decay_rate': 0.1,
}

def get_base_params(problem_type: str = None, problem_size: str = None, **overrides) -> Dict[str, Any]:
    """
    Build an algorithm parameter dictionary from three layers

    Layers are applied in increasing priority:
    1. BASE_ALGORITHM_PARAMS (always)
    2. BEST_CONFIG[problem_type][problem_size] (only if both arguments are
       given and the entry exists; unknown names are silently ignored)
    3. keyword overrides

    Parameters:
    -----------
    problem_type : str, optional
        Problem type whose tuned configuration should be applied
    problem_size : str, optional
        Size label whose tuned configuration should be applied
    **overrides : dict
        Parameters that take precedence over both other layers

    Returns:
    --------
    dict : Newly created parameter dictionary (the module constants are not modified)
    """
    tuned = {}
    if problem_type and problem_size:
        # Missing type or size falls back to "no tuned values"
        tuned = BEST_CONFIG.get(problem_type, {}).get(problem_size, {})

    # Later mappings win on key collisions
    return {**BASE_ALGORITHM_PARAMS, **tuned, **overrides}

# ============================================================================
# ABLATION STUDY 1: DECOMPOSITION SIZE
# ============================================================================

def ablation_decomposition_size(
    problem_type: str = 'BiTSP',
    problem_size: str = 'medium',
    decomposition_sizes: Optional[List[int]] = None,
    num_runs: int = 5,
    output_dir: str = 'ablation_results',
    metrics_to_plot: Optional[List[str]] = None,
    plot_mode: str = 'grouped',  # 'grouped' or 'separate'
    figure_size: Optional[Tuple[float, float]] = None
):
    """
    Ablation Study 1: Effect of decomposition size on both UCB and Thompson Sampling

    For every decomposition size, both algorithm wrappers are evaluated
    `num_runs` times. Raw results are written to a timestamped YAML file in
    `output_dir` and then plotted via `plot_decomposition_ablation`.
    A failing run is reported on stdout and skipped (best effort), so a
    configuration may end up with fewer than `num_runs` entries.

    Parameters:
    -----------
    problem_type : str
        One of 'BiTSP', 'BiKP', 'TriTSP'
    problem_size : str
        One of 'small', 'medium', 'large'
    decomposition_sizes : List[int]
        List of decomposition sizes to test (if None, auto-generate based on problem size)
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    metrics_to_plot : List[str]
        Metrics to plot. Options: 'hypervolume', 'runtime', 'solutions'.
        If None, all three are plotted. Unknown names are dropped.
    plot_mode : str
        'grouped' for single figure with subplots, 'separate' for individual files
    figure_size : Tuple[float, float]
        Custom figure size (width, height). If None, auto-determined based on plot_mode

    Returns:
    --------
    dict : {'UCB': {size: [run, ...]}, 'Thompson': {size: [run, ...]}} where each
        run is a dict with keys 'hypervolume', 'runtime' and 'num_solutions'

    Raises:
    -------
    ValueError
        If no valid metric is selected, or if the problem type or problem
        size is unknown
    """

    print("\n" + "="*80)
    print("ABLATION STUDY 1: DECOMPOSITION SIZE ANALYSIS")
    print("="*80)

    # Validate metrics (None selects every metric)
    valid_metrics = ['hypervolume', 'runtime', 'solutions']
    if metrics_to_plot is None:
        metrics_to_plot = list(valid_metrics)
    metrics_to_plot = [m for m in metrics_to_plot if m in valid_metrics]
    if not metrics_to_plot:
        raise ValueError(f"No valid metrics selected. Choose from: {valid_metrics}")

    # Get problem configuration
    if problem_type not in PROBLEM_REGISTRY:
        raise ValueError(f"Unknown problem type: {problem_type}")

    problem_config = PROBLEM_REGISTRY[problem_type]
    if problem_size not in problem_config['sizes']:
        # Report an unknown size the same way as an unknown problem type
        raise ValueError(
            f"Unknown problem size: {problem_size}. "
            f"Choose from: {list(problem_config['sizes'])}"
        )
    problem_class = problem_config['class']
    problem_params = problem_config['sizes'][problem_size]

    # Determine problem size parameter
    if 'n_cities' in problem_params:
        actual_size = problem_params['n_cities']
    elif 'n_items' in problem_params:
        actual_size = problem_params['n_items']
    else:
        actual_size = 50

    # Auto-generate decomposition sizes if not provided
    if decomposition_sizes is None:
        candidates = [
            max(5, actual_size // 10),   # 10% of problem size
            max(10, actual_size // 5),   # 20%
            max(15, actual_size // 3),   # 33%
            max(20, actual_size // 2),   # 50%
            max(30, 2 * actual_size // 3)  # 67%
        ]
        # Keep only sizes smaller than the problem, without duplicates, ascending
        decomposition_sizes = sorted({d for d in candidates if d < actual_size})

    print(f"\nProblem: {problem_type} ({problem_size})")
    print(f"Actual size: {actual_size}")
    print(f"Testing decomposition sizes: {decomposition_sizes}")
    print(f"Runs per configuration: {num_runs}")
    print(f"Metrics to plot: {metrics_to_plot}")
    print(f"Plot mode: {plot_mode}")

    # Get reference point (the registry callable expects the numeric size)
    ref_point = problem_config['default_ref'](actual_size)
    print(f"Reference point: {ref_point}")

    # Storage for results
    results = {
        'UCB': defaultdict(list),
        'Thompson': defaultdict(list)
    }

    # Test each decomposition size
    for decomp_size in decomposition_sizes:
        print(f"\n{'-'*60}")
        print(f"Testing decomposition size: {decomp_size}")
        print(f"{'-'*60}")

        # Overlap is one third of the decomposition size, but at least 2
        overlap = max(2, decomp_size // 3)

        # Get base parameters with overrides for this specific test
        base_params = get_base_params(
            decomposition_size=decomp_size,
            overlap=overlap,
            n_weight_vectors=15,
            nb_rounds=5,
            patience=20
        )

        # (results key, display label, wrapper class, parameters)
        # Only UCB receives 'ucb_coefficient'; Thompson runs on the base parameters.
        variants = [
            ('UCB', 'UCB', UCBWrapper, {**base_params, 'ucb_coefficient': 3.0}),
            ('Thompson', 'Thompson Sampling', ThompsonWrapper, dict(base_params)),
        ]

        for key, label, wrapper_class, params in variants:
            print(f"\n  Testing {label}...")

            for run in range(num_runs):
                try:
                    evaluator = MOCOEvaluator(reference_point=ref_point)
                    result = evaluator.evaluate_algorithm(
                        algorithm_class=wrapper_class,
                        problem_class=problem_class,
                        algorithm_name=f"{key}_decomp{decomp_size}",
                        parameters=params,
                        problem_params=problem_params,
                        num_runs=1
                    )

                    results[key][decomp_size].append({
                        'hypervolume': result.hypervolume,
                        'runtime': result.runtime,
                        'num_solutions': result.num_nondominated
                    })

                    print(f"    Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, Time: {result.runtime:.2f}s")

                except Exception as e:
                    # Deliberate best effort: one failing run must not abort the study
                    print(f"    Run {run+1} failed: {e}")

    # Save results
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    def _as_builtin(value):
        # NumPy scalars/arrays (and tensors) expose .tolist(); builtin numbers do not.
        # Without this, yaml.dump would write Python-specific tags that
        # yaml.safe_load cannot read back.
        return value.tolist() if hasattr(value, 'tolist') else value

    serializable_results = {
        alg: {
            size: [
                {name: _as_builtin(value) for name, value in run_result.items()}
                for run_result in runs
            ]
            for size, runs in per_size.items()
        }
        for alg, per_size in results.items()
    }

    results_file = os.path.join(output_dir, f'ablation1_decomp_{problem_type}_{problem_size}_{timestamp}.yaml')
    with open(results_file, 'w') as f:
        yaml.dump({
            'problem_type': problem_type,
            'problem_size': problem_size,
            'actual_size': actual_size,
            'decomposition_sizes': decomposition_sizes,
            'num_runs': num_runs,
            'results': serializable_results
        }, f, default_flow_style=False)

    print(f"\nResults saved to: {results_file}")

    # Plotting needs at least one successful run; otherwise the DataFrame is empty
    has_data = any(runs for per_size in results.values() for runs in per_size.values())
    if not has_data:
        print("\nNo successful runs; skipping plots.")
        return results

    # Generate plots
    plot_decomposition_ablation(
        results, decomposition_sizes, problem_type, problem_size,
        output_dir, timestamp, metrics_to_plot, plot_mode, figure_size, actual_size
    )

    return results


def plot_decomposition_ablation(
    results, decomp_sizes, problem_type, problem_size, 
    output_dir, timestamp, metrics_to_plot, plot_mode='grouped', 
    figure_size=None, actual_size=None
):
    """Generate ICML-style plots for decomposition size ablation

    Draws one line plot per metric (mean with a standard-deviation band,
    one line per algorithm), saves each figure as PNG and PDF and writes a
    mean/std summary table as CSV. If `results` holds no runs, nothing is
    plotted or written and the function returns immediately.

    Parameters:
    -----------
    results : dict
        {'UCB': {size: [run, ...]}, 'Thompson': {size: [run, ...]}}; each run
        is a dict with keys 'hypervolume', 'runtime', 'num_solutions'.
        Missing algorithms or sizes are treated as "no runs".
    decomp_sizes : list
        Decomposition sizes to include, in plotting order
    problem_type, problem_size : str
        Used in figure titles and output file names
    output_dir : str
        Directory for figures and summary (created if missing)
    timestamp : str
        Suffix used in all output file names
    metrics_to_plot : list of str
        Subset of 'hypervolume', 'runtime', 'solutions'
    plot_mode : str
        'grouped' for one figure with subplots; any other value produces
        one figure per metric
    figure_size : tuple, optional
        (width, height); defaults depend on plot_mode
    actual_size : int, optional
        Numeric problem size shown in the titles
    """

    print("\nGenerating plots...")

    # Prepare data: all UCB rows first, then all Thompson rows.
    # .get() is used so that lookups never insert keys into a defaultdict
    # and so that plain dicts with missing entries are accepted as well.
    algorithm_labels = (('UCB', 'UCB-Hedge'), ('Thompson', 'Thompson-Hedge'))
    rows = [
        {
            'Decomposition Size': decomp_size,
            'Hypervolume': run_result['hypervolume'],
            'Runtime (s)': run_result['runtime'],
            'Solutions': run_result['num_solutions'],
            'Algorithm': label
        }
        for key, label in algorithm_labels
        for decomp_size in decomp_sizes
        for run_result in results.get(key, {}).get(decomp_size, [])
    ]

    if not rows:
        # An empty DataFrame has no columns; seaborn and groupby would raise on it
        print("No successful runs to plot; skipping figures and summary.")
        return

    df = pd.DataFrame(rows)
    os.makedirs(output_dir, exist_ok=True)

    # Define metric configurations
    metric_config = {
        'hypervolume': {
            'column': 'Hypervolume',
            'ylabel': 'Hypervolume',
            'title': 'Hypervolume vs Decomposition Size',
            'marker': 'o'
        },
        'runtime': {
            'column': 'Runtime (s)',
            'ylabel': 'Runtime (seconds)',
            'title': 'Runtime vs Decomposition Size',
            'marker': 's'
        },
        'solutions': {
            'column': 'Solutions',
            'ylabel': 'Number of Solutions',
            'title': 'Pareto Front Size vs Decomposition Size',
            'marker': '^'
        }
    }

    def draw_metric(ax, config, title):
        # One metric on one axes: mean line per algorithm with an SD band
        sns.lineplot(
            data=df, x='Decomposition Size', y=config['column'],
            hue='Algorithm', marker=config['marker'], ax=ax,
            linewidth=2, markersize=7, err_style='band', errorbar='sd'
        )
        ax.set_title(title, fontweight='bold', pad=10)
        ax.set_ylabel(config['ylabel'])
        ax.set_xlabel('Decomposition Size')
        ax.legend(frameon=True, loc='best')
        ax.grid(True, alpha=0.3, linestyle='--')

    def save_and_close(fig, tag):
        # Write PNG and PDF next to each other, then release the figure
        plot_file = os.path.join(
            output_dir,
            f'ablation1_decomp_{problem_type}_{problem_size}_{tag}_{timestamp}.png'
        )
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
        # Swap only the extension; str.replace would also hit '.png' inside directory names
        fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
        print(f"{tag.capitalize()} plot saved to: {plot_file}")
        plt.close(fig)

    n_metrics = len(metrics_to_plot)

    if plot_mode == 'grouped':
        # Single figure with subplots
        if figure_size is None:
            figure_size = (5 * n_metrics, 3.5)

        # squeeze=False always yields a 2-D array, so one metric needs no special case
        fig, axes_grid = plt.subplots(1, n_metrics, figsize=figure_size, squeeze=False)

        for ax, metric in zip(axes_grid[0], metrics_to_plot):
            config = metric_config[metric]
            draw_metric(ax, config, config['title'])

        # Overall title
        fig.suptitle(
            f'Decomposition Size Effect: {problem_type} (n={actual_size})',
            fontsize=12, fontweight='bold', y=0.98
        )

        fig.tight_layout()
        save_and_close(fig, 'grouped')

    else:  # separate mode
        # Individual figures for each metric
        if figure_size is None:
            figure_size = (5, 3.5)

        for metric in metrics_to_plot:
            config = metric_config[metric]

            fig, ax = plt.subplots(1, 1, figsize=figure_size)
            draw_metric(ax, config, f'{config["title"]}\n{problem_type} (n={actual_size})')

            fig.tight_layout()
            save_and_close(fig, metric)

    # Create statistical summary table
    summary_stats = df.groupby(['Algorithm', 'Decomposition Size']).agg({
        'Hypervolume': ['mean', 'std'],
        'Runtime (s)': ['mean', 'std'],
        'Solutions': ['mean', 'std']
    }).round(4)

    print("\n" + "="*80)
    print("STATISTICAL SUMMARY")
    print("="*80)
    print(summary_stats)

    # Save summary table
    summary_file = os.path.join(
        output_dir, 
        f'ablation1_summary_{problem_type}_{problem_size}_{timestamp}.csv'
    )
    summary_stats.to_csv(summary_file)
    print(f"\nSummary statistics saved to: {summary_file}")


# ============================================================================
# ABLATION STUDY 2: FTRL ON/OFF
# ============================================================================

def _to_yaml_safe(value):
    """Recursively convert `value` into builtin types that yaml.safe_load can read back.

    Dicts and lists/tuples are rebuilt (tuples become lists); objects that
    offer `.tolist()` (NumPy scalars/arrays, tensors) are converted with it;
    everything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _to_yaml_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_yaml_safe(item) for item in value]
    if hasattr(value, 'detach'):
        # presumably a torch tensor: drop the autograd graph and move to host memory
        value = value.detach().cpu()
    if hasattr(value, 'tolist'):
        return value.tolist()
    return value


def ablation_ftrl_toggle(
    problem_type: str = 'BiTSP',
    problem_size: str = 'medium',
    num_runs: int = 10,
    output_dir: str = 'ablation_results',
    metrics_to_plot: Optional[List[str]] = None,
    plot_mode: str = 'grouped',
    figure_size: Optional[Tuple[float, float]] = None,
    save_pareto_fronts: bool = False
):
    """
    Ablation Study 2: Effect of FTRL component on both UCB and Thompson Sampling

    Runs four configurations (UCB / Thompson, each with and without FTRL)
    `num_runs` times, writes the raw results to a timestamped YAML file in
    `output_dir`, plots them via `plot_ftrl_ablation` and optionally exports
    the Pareto fronts via `save_pareto_front_data`.
    A failing run is reported on stdout and skipped (best effort).

    Parameters:
    -----------
    problem_type : str
        One of 'BiTSP', 'BiKP', 'TriTSP'
    problem_size : str
        One of 'small', 'medium', 'large'
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    metrics_to_plot : List[str]
        Metrics to plot. Options: 'hypervolume', 'runtime', 'solutions'.
        If None, all three are plotted. Unknown names are dropped.
    plot_mode : str
        'grouped' for single figure with subplots, 'separate' for individual files
    figure_size : Tuple[float, float]
        Custom figure size (width, height)
    save_pareto_fronts : bool
        Whether to save Pareto front data after completion

    Returns:
    --------
    dict : configuration name -> list of run dicts with keys 'hypervolume',
        'runtime', 'num_solutions' and 'objectives' (values as returned by
        the evaluator)

    Raises:
    -------
    ValueError
        If no valid metric is selected, or if the problem type or problem
        size is unknown
    """

    print("\n" + "="*80)
    print("ABLATION STUDY 2: FTRL COMPONENT ANALYSIS")
    print("="*80)

    # Validate metrics (None selects every metric)
    valid_metrics = ['hypervolume', 'runtime', 'solutions']
    if metrics_to_plot is None:
        metrics_to_plot = list(valid_metrics)
    metrics_to_plot = [m for m in metrics_to_plot if m in valid_metrics]
    if not metrics_to_plot:
        raise ValueError(f"No valid metrics selected. Choose from: {valid_metrics}")

    # Get problem configuration
    if problem_type not in PROBLEM_REGISTRY:
        raise ValueError(f"Unknown problem type: {problem_type}")

    problem_config = PROBLEM_REGISTRY[problem_type]
    if problem_size not in problem_config['sizes']:
        # Report an unknown size the same way as an unknown problem type
        raise ValueError(
            f"Unknown problem size: {problem_size}. "
            f"Choose from: {list(problem_config['sizes'])}"
        )
    problem_class = problem_config['class']
    problem_params = problem_config['sizes'][problem_size]

    # Determine problem size parameter
    if 'n_cities' in problem_params:
        actual_size = problem_params['n_cities']
    elif 'n_items' in problem_params:
        actual_size = problem_params['n_items']
    else:
        actual_size = 50

    print(f"\nProblem: {problem_type} ({problem_size})")
    print(f"Actual size: {actual_size}")
    print(f"Runs per configuration: {num_runs}")
    print(f"Metrics to plot: {metrics_to_plot}")
    print(f"Plot mode: {plot_mode}")
    print(f"Save Pareto fronts: {save_pareto_fronts}")

    # Get reference point (the registry callable expects the numeric size)
    ref_point = problem_config['default_ref'](actual_size)
    print(f"Reference point: {ref_point}")

    # Get base parameters with best config for this problem
    base_params = get_base_params(
        problem_type=problem_type,
        problem_size=problem_size,
        learning_rate=0.1,  # Lower than the base default for the FTRL ablation
        max_iterations=250,
        use_learned_operators=True
    )

    # Test configurations
    # config_name, wrapper_class, use_ftrl, ucb_coef (None = parameter not passed)
    configs = [
        ('UCB_with_FTRL', UCBWrapper, True, 3.0),
        ('UCB_without_FTRL', UCBWrapper, False, 3.0),
        ('Thompson_with_FTRL', ThompsonWrapper, True, None),
        ('Thompson_without_FTRL', ThompsonWrapper, False, None),
    ]

    # Storage for results, one list of runs per configuration
    results = {config_name: [] for config_name, *_ in configs}

    for config_name, wrapper_class, use_ftrl, ucb_coef in configs:
        print(f"\n{'-'*60}")
        print(f"Testing: {config_name}")
        print(f"{'-'*60}")

        # Set parameters
        params = base_params.copy()
        params['use_ftrl'] = use_ftrl
        if ucb_coef is not None:
            params['ucb_coefficient'] = ucb_coef

        for run in range(num_runs):
            try:
                evaluator = MOCOEvaluator(reference_point=ref_point)
                result = evaluator.evaluate_algorithm(
                    algorithm_class=wrapper_class,
                    problem_class=problem_class,
                    algorithm_name=config_name,
                    parameters=params,
                    problem_params=problem_params,
                    num_runs=1
                )

                results[config_name].append({
                    'hypervolume': result.hypervolume,
                    'runtime': result.runtime,
                    'num_solutions': result.num_nondominated,
                    'objectives': result.objectives  # Save actual Pareto front
                })

                print(f"  Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, Time: {result.runtime:.2f}s")

            except Exception as e:
                # Deliberate best effort: one failing run must not abort the study
                print(f"  Run {run+1} failed: {e}")

    # Save results
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    results_file = os.path.join(
        output_dir, 
        f'ablation2_ftrl_{problem_type}_{problem_size}_{timestamp}.yaml'
    )
    with open(results_file, 'w') as f:
        # The file is read back with yaml.safe_load by save_pareto_front_data,
        # so it must contain builtin types only (no NumPy/tensor objects).
        yaml.dump({
            'problem_type': problem_type,
            'problem_size': problem_size,
            'actual_size': actual_size,
            'num_runs': num_runs,
            'results': _to_yaml_safe(results)
        }, f, default_flow_style=False)

    print(f"\nResults saved to: {results_file}")

    # Plotting needs at least one successful run; otherwise the DataFrame is empty
    if not any(results.values()):
        print("\nNo successful runs; skipping plots and Pareto front export.")
        return results

    # Generate plots
    plot_ftrl_ablation(
        results, problem_type, problem_size, 
        output_dir, timestamp, metrics_to_plot, plot_mode, figure_size, actual_size
    )

    # Save Pareto fronts if requested
    if save_pareto_fronts:
        print("\n" + "="*80)
        print("SAVING PARETO FRONT DATA")
        print("="*80)

        # Pareto data goes into its own subdirectory
        pareto_dir = os.path.join(output_dir, 'pareto_fronts')

        # Extract the fronts from the YAML file written above
        save_pareto_front_data(
            yaml_file=results_file,
            output_dir=pareto_dir,
            save_format='both'
        )

        print(f"\n✓ Pareto fronts saved to: {pareto_dir}/")

    return results


def plot_ftrl_ablation(
    results, problem_type, problem_size, 
    output_dir, timestamp, metrics_to_plot, plot_mode='grouped',
    figure_size=None, actual_size=None
):
    """Generate ICML-style plots for FTRL ablation

    Draws, per metric, box plots grouped by algorithm and split by FTRL
    status, with the individual runs overlaid as points. Each figure is saved
    as PNG and PDF. A with/without-FTRL comparison is printed and a summary
    table is written as CSV. If `results` holds no runs, nothing is plotted
    or written and the function returns immediately.

    Parameters:
    -----------
    results : dict
        Configuration name -> list of run dicts with keys 'hypervolume',
        'runtime', 'num_solutions'. Names containing 'UCB' are labelled
        'UCB-Hedge', all others 'Thompson-Hedge'; names containing
        'with_FTRL' are labelled 'With FTRL', all others 'Without FTRL'.
    problem_type, problem_size : str
        Used in figure titles and output file names
    output_dir : str
        Directory for figures and summary (created if missing)
    timestamp : str
        Suffix used in all output file names
    metrics_to_plot : list of str
        Subset of 'hypervolume', 'runtime', 'solutions'
    plot_mode : str
        'grouped' for one figure with subplots; any other value produces
        one figure per metric
    figure_size : tuple, optional
        (width, height); defaults depend on plot_mode
    actual_size : int, optional
        Numeric problem size shown in the titles
    """

    print("\nGenerating plots...")

    # Prepare data: one row per run, labelled from the configuration name
    rows = []
    for config_name, runs in results.items():
        algorithm = 'UCB-Hedge' if 'UCB' in config_name else 'Thompson-Hedge'
        ftrl_status = 'With FTRL' if 'with_FTRL' in config_name else 'Without FTRL'

        rows.extend(
            {
                'Algorithm': algorithm,
                'FTRL': ftrl_status,
                'Hypervolume': run_result['hypervolume'],
                'Runtime (s)': run_result['runtime'],
                'Solutions': run_result['num_solutions']
            }
            for run_result in runs
        )

    if not rows:
        # An empty DataFrame has no columns; seaborn and the filters below would raise on it
        print("No successful runs to plot; skipping figures and summary.")
        return

    df = pd.DataFrame(rows)
    os.makedirs(output_dir, exist_ok=True)

    # Define metric configurations
    metric_config = {
        'hypervolume': {
            'column': 'Hypervolume',
            'ylabel': 'Hypervolume',
            'title': 'Hypervolume Comparison'
        },
        'runtime': {
            'column': 'Runtime (s)',
            'ylabel': 'Runtime (seconds)',
            'title': 'Runtime Comparison'
        },
        'solutions': {
            'column': 'Solutions',
            'ylabel': 'Number of Solutions',
            'title': 'Pareto Front Size Comparison'
        }
    }

    # Fixed colors so both FTRL states look the same in every figure
    palette = {'With FTRL': '#0173B2', 'Without FTRL': '#DE8F05'}

    def draw_metric(ax, config, title):
        # Box plot per algorithm, split by FTRL status
        sns.boxplot(
            data=df, x='Algorithm', y=config['column'], hue='FTRL',
            ax=ax, palette=palette, linewidth=1.2, width=0.6
        )
        # Individual runs on top; legend=False avoids duplicate legend entries
        sns.stripplot(
            data=df, x='Algorithm', y=config['column'], hue='FTRL',
            ax=ax, dodge=True, alpha=0.5, size=3,
            palette=palette, legend=False
        )
        ax.set_title(title, fontweight='bold', pad=10)
        ax.set_ylabel(config['ylabel'])
        ax.set_xlabel('')
        ax.legend(title='FTRL Status', frameon=True, loc='best')
        ax.grid(True, alpha=0.3, linestyle='--', axis='y')

    def save_and_close(fig, tag):
        # Write PNG and PDF next to each other, then release the figure
        plot_file = os.path.join(
            output_dir,
            f'ablation2_ftrl_{problem_type}_{problem_size}_{tag}_{timestamp}.png'
        )
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
        # Swap only the extension; str.replace would also hit '.png' inside directory names
        fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
        print(f"{tag.capitalize()} plot saved to: {plot_file}")
        plt.close(fig)

    n_metrics = len(metrics_to_plot)

    if plot_mode == 'grouped':
        # Single figure with subplots
        if figure_size is None:
            figure_size = (5 * n_metrics, 4)

        # squeeze=False always yields a 2-D array, so one metric needs no special case
        fig, axes_grid = plt.subplots(1, n_metrics, figsize=figure_size, squeeze=False)

        for ax, metric in zip(axes_grid[0], metrics_to_plot):
            config = metric_config[metric]
            draw_metric(ax, config, config['title'])

        # Overall title
        fig.suptitle(
            f'FTRL Component Effect: {problem_type} (n={actual_size})',
            fontsize=12, fontweight='bold', y=0.98
        )

        fig.tight_layout()
        save_and_close(fig, 'grouped')

    else:  # separate mode
        # Individual figures for each metric
        if figure_size is None:
            figure_size = (5, 4)

        for metric in metrics_to_plot:
            config = metric_config[metric]

            fig, ax = plt.subplots(1, 1, figsize=figure_size)
            draw_metric(ax, config, f'{config["title"]}\n{problem_type} (n={actual_size})')

            fig.tight_layout()
            save_and_close(fig, metric)

    # Create statistical comparison
    print("\n" + "="*80)
    print("STATISTICAL COMPARISON")
    print("="*80)

    for algorithm in ['UCB-Hedge', 'Thompson-Hedge']:
        print(f"\n{algorithm}:")
        of_algorithm = df[df['Algorithm'] == algorithm]
        with_ftrl = of_algorithm[of_algorithm['FTRL'] == 'With FTRL']
        without_ftrl = of_algorithm[of_algorithm['FTRL'] == 'Without FTRL']

        hv_with = with_ftrl['Hypervolume'].mean()
        hv_without = without_ftrl['Hypervolume'].mean()

        print(f"  Hypervolume:")
        print(f"    With FTRL:    {hv_with:.4f} ± {with_ftrl['Hypervolume'].std():.4f}")
        print(f"    Without FTRL: {hv_without:.4f} ± {without_ftrl['Hypervolume'].std():.4f}")
        # Relative change vs. the run without FTRL; undefined if either side has
        # no runs (mean is NaN) or the baseline is exactly zero
        if pd.notna(hv_with) and pd.notna(hv_without) and hv_without != 0:
            hv_improvement = (hv_with - hv_without) / hv_without * 100
            print(f"    Improvement:  {hv_improvement:.2f}%")
        else:
            print("    Improvement:  n/a")

        print(f"  Runtime (s):")
        print(f"    With FTRL:    {with_ftrl['Runtime (s)'].mean():.2f} ± {with_ftrl['Runtime (s)'].std():.2f}")
        print(f"    Without FTRL: {without_ftrl['Runtime (s)'].mean():.2f} ± {without_ftrl['Runtime (s)'].std():.2f}")

        print(f"  Solutions:")
        print(f"    With FTRL:    {with_ftrl['Solutions'].mean():.1f} ± {with_ftrl['Solutions'].std():.1f}")
        print(f"    Without FTRL: {without_ftrl['Solutions'].mean():.1f} ± {without_ftrl['Solutions'].std():.1f}")

    # Statistical summary table
    summary_stats = df.groupby(['Algorithm', 'FTRL']).agg({
        'Hypervolume': ['mean', 'std', 'min', 'max'],
        'Runtime (s)': ['mean', 'std', 'min', 'max'],
        'Solutions': ['mean', 'std', 'min', 'max']
    }).round(4)

    summary_file = os.path.join(
        output_dir,
        f'ablation2_summary_{problem_type}_{problem_size}_{timestamp}.csv'
    )
    summary_stats.to_csv(summary_file)
    print(f"\nSummary statistics saved to: {summary_file}")


def _summarize_metric(values: List[float]) -> Dict[str, float]:
    """Return mean/std/min/max of ``values`` as plain Python floats.

    ``np.std`` is called with its default settings. Results are converted
    to built-in ``float`` so the summary can be written with ``json.dump``.
    """
    return {
        'mean': float(np.mean(values)),
        'std': float(np.std(values)),
        'min': float(np.min(values)),
        'max': float(np.max(values)),
    }


def save_pareto_front_data(
    yaml_file: str,
    output_dir: str = 'pareto_data',
    save_format: str = 'both'  # 'pickle', 'json', or 'both'
) -> Dict[str, Any]:
    """
    Extract and save Pareto front data from ablation results

    Reads the ``results`` mapping of the YAML file, collects per-run
    hypervolume / runtime / solution-count values for every configuration,
    computes summary statistics, and writes the collected data to
    ``output_dir`` under a timestamped file name.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file containing ablation results
    output_dir : str
        Directory to save Pareto front data (created if missing)
    save_format : str
        'pickle' for binary, 'json' for text, 'both' for both formats

    Returns:
    --------
    pareto_data : dict
        Dictionary with a 'metadata' entry, a 'configurations' entry keyed
        by "<algorithm> <FTRL status>", and - only when at least one run
        carries an 'objectives' field - a 'pareto_fronts' entry with the
        same keys.

    Raises:
    -------
    ValueError
        If ``save_format`` is not one of 'pickle', 'json', 'both'.
    KeyError
        If the YAML document has no 'results' entry, or a run lacks one of
        'hypervolume', 'runtime', 'num_solutions'.
    """

    # Fail fast: an unrecognized format would otherwise skip both writers
    # below and return without having saved anything.
    if save_format not in ('pickle', 'json', 'both'):
        raise ValueError(
            f"save_format must be 'pickle', 'json', or 'both', got {save_format!r}"
        )

    print("\n" + "="*80)
    print("EXTRACTING PARETO FRONT DATA")
    print("="*80)

    # Load YAML data
    with open(yaml_file, 'r') as f:
        data = yaml.safe_load(f)

    problem_type = data.get('problem_type', 'Unknown')
    problem_size = data.get('problem_size', 'Unknown')
    actual_size = data.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    # Dictionary to store Pareto data
    pareto_data = {
        'metadata': {
            'problem_type': problem_type,
            'problem_size': problem_size,
            'actual_size': actual_size,
            'num_runs': data.get('num_runs', 0)
        },
        'configurations': {}
    }

    # Process each configuration
    for config_name, runs in data['results'].items():
        # Parse config name: anything without 'UCB' is treated as Thompson,
        # anything without 'with_FTRL' is treated as the no-FTRL variant.
        algorithm = 'UCB-Hedge' if 'UCB' in config_name else 'Thompson-Hedge'
        ftrl_status = 'With FTRL' if 'with_FTRL' in config_name else 'Without FTRL'
        full_config = f"{algorithm} {ftrl_status}"

        print(f"\nProcessing: {full_config}")

        # A configuration with no recorded runs cannot be summarized
        # (np.min / np.max raise on empty input), so skip it explicitly.
        if not runs:
            print("  ⚠ Warning: no runs recorded, skipping")
            continue

        # Extract data points
        hypervolumes = [run['hypervolume'] for run in runs]
        runtimes = [run['runtime'] for run in runs]
        num_solutions = [run['num_solutions'] for run in runs]
        # Objectives are optional; runs without them are simply omitted, so
        # this list may be shorter than the number of runs.
        pareto_fronts = [run['objectives'] for run in runs if 'objectives' in run]

        statistics = {
            'hypervolume': _summarize_metric(hypervolumes),
            'runtime': _summarize_metric(runtimes),
            'num_solutions': _summarize_metric(num_solutions),
        }

        # Store configuration data
        pareto_data['configurations'][full_config] = {
            'algorithm': algorithm,
            'ftrl_status': ftrl_status,
            'hypervolume': hypervolumes,
            'runtime': runtimes,
            'num_solutions': num_solutions,
            'statistics': statistics,
        }

        # Store Pareto fronts separately (not in configurations)
        if pareto_fronts:
            pareto_data.setdefault('pareto_fronts', {})[full_config] = pareto_fronts
            print(f"  Saved {len(pareto_fronts)} Pareto fronts")

        hv_stats = statistics['hypervolume']
        rt_stats = statistics['runtime']
        print(f"  Runs: {len(hypervolumes)}")
        print(f"  HV: {hv_stats['mean']:.4f} ± {hv_stats['std']:.4f}")
        print(f"  Runtime: {rt_stats['mean']:.2f} ± {rt_stats['std']:.2f}s")

    # Save data
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    base_filename = f'pareto_data_{problem_type}_{problem_size}_{timestamp}'

    if save_format in ['pickle', 'both']:
        pickle_file = os.path.join(output_dir, f'{base_filename}.pkl')
        with open(pickle_file, 'wb') as f:
            pickle.dump(pareto_data, f)
        print(f"\n✓ Saved pickle file: {pickle_file}")

    if save_format in ['json', 'both']:
        json_file = os.path.join(output_dir, f'{base_filename}.json')
        with open(json_file, 'w') as f:
            json.dump(pareto_data, f, indent=2)
        print(f"✓ Saved JSON file: {json_file}")

    return pareto_data


# ============================================================================
# MAIN RUNNER WITH PARETO FRONT SAVING
# ============================================================================

def run_all_ablations(
    problem_type: str = 'BiTSP',
    problem_size: str = 'medium',
    num_runs: int = 5,
    output_dir: str = 'ablation_results',
    metrics_to_plot: Optional[List[str]] = None,
    plot_mode: str = 'grouped',
    save_pareto_fronts: bool = True,  # Enable Pareto front saving
):
    """
    Run all ablation studies

    Runs the decomposition-size ablation followed by the FTRL-toggle
    ablation with the same settings, then optionally post-processes the
    newest FTRL result YAML in ``output_dir`` into Pareto front data files.

    Parameters:
    -----------
    problem_type : str
        One of 'BiTSP', 'BiKP', 'TriTSP'
    problem_size : str
        One of 'small', 'medium', 'large'
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    metrics_to_plot : List[str], optional
        Metrics to plot. Options: 'hypervolume', 'runtime', 'solutions'.
        Defaults to all three when omitted.
    plot_mode : str
        'grouped' or 'separate'
    save_pareto_fronts : bool
        Whether to save Pareto front data

    Returns:
    --------
    (results_1, results_2) : tuple
        Return values of ``ablation_decomposition_size`` and
        ``ablation_ftrl_toggle``, in that order.
    """

    # Build the default per call instead of sharing one mutable list object
    # across all invocations.
    if metrics_to_plot is None:
        metrics_to_plot = ['hypervolume', 'runtime', 'solutions']

    print("\n" + "="*80)
    print("RUNNING ALL ABLATION STUDIES")
    print("="*80)
    print(f"Problem: {problem_type} ({problem_size})")
    print(f"Runs per configuration: {num_runs}")
    print(f"Output directory: {output_dir}")
    print(f"Save Pareto fronts: {save_pareto_fronts}")
    print(f"Metrics: {metrics_to_plot}")
    print(f"Plot mode: {plot_mode}")

    overall_start = time.time()

    # Ablation 1: Decomposition Size
    print("\n" + "="*80)
    print("Starting Ablation Study 1: Decomposition Size")
    print("="*80)
    results_1 = ablation_decomposition_size(
        problem_type=problem_type,
        problem_size=problem_size,
        num_runs=num_runs,
        output_dir=output_dir,
        metrics_to_plot=metrics_to_plot,
        plot_mode=plot_mode
    )

    # Ablation 2: FTRL Toggle
    # NOTE(review): save_pareto_fronts is not forwarded here, so whether the
    # YAML contains per-run objectives depends on ablation_ftrl_toggle's own
    # default - confirm against its signature.
    print("\n" + "="*80)
    print("Starting Ablation Study 2: FTRL Component")
    print("="*80)
    results_2 = ablation_ftrl_toggle(
        problem_type=problem_type,
        problem_size=problem_size,
        num_runs=num_runs,
        output_dir=output_dir,
        metrics_to_plot=metrics_to_plot,
        plot_mode=plot_mode
    )

    overall_time = time.time() - overall_start

    print("\n" + "="*80)
    print("ALL ABLATION STUDIES COMPLETED")
    print("="*80)
    print(f"Total time: {overall_time/60:.2f} minutes")
    print(f"Results saved to: {output_dir}/")

    # Process Pareto fronts if requested
    if save_pareto_fronts:
        print("\n" + "="*80)
        print("PROCESSING PARETO FRONTS")
        print("="*80)

        # Find the FTRL ablation YAML file (the one we want Pareto fronts for)
        yaml_files = glob.glob(os.path.join(output_dir, f'ablation2_ftrl_{problem_type}_{problem_size}_*.yaml'))

        if yaml_files:
            # Use the most recently written file. Modification time is used
            # because getctime is the metadata-change time on Unix.
            ftrl_yaml_file = max(yaml_files, key=os.path.getmtime)
            print(f"Processing Pareto fronts from: {ftrl_yaml_file}")

            # Pareto data goes into its own subdirectory of output_dir
            pareto_dir = os.path.join(output_dir, 'pareto_fronts')

            # Save Pareto front data (return value not needed here)
            save_pareto_front_data(
                yaml_file=ftrl_yaml_file,
                output_dir=pareto_dir,
                save_format='both'
            )

            print(f"\n✓ Pareto fronts saved to: {pareto_dir}/")
        else:
            print(f"⚠ Warning: No FTRL ablation YAML files found in {output_dir}")

    return results_1, results_2


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":

    # The disabled snippets below are kept as ready-to-use invocation
    # templates; only the FTRL ablation (Example 3) actually runs.

    # Example 1 (disabled): decomposition-size sweep over explicit sizes,
    # all three metrics, plot_mode='separate'
    # results = ablation_decomposition_size(
    #     problem_type='BiTSP',
    #     problem_size='medium',
    #     decomposition_sizes=[3,5,10,15,25,30,40,45],
    #     num_runs=50,
    #     metrics_to_plot=['hypervolume', 'runtime', 'solutions'],
    #     plot_mode='separate',
    #     output_dir='ablation_results'
    # )

    # Example 2 (disabled): decomposition-size sweep, hypervolume only,
    # plot_mode='separate'
    # results = ablation_decomposition_size(
    #     problem_type='BiKP',
    #     problem_size='medium',
    #     num_runs=5,
    #     metrics_to_plot=['hypervolume'],  # Only HV
    #     plot_mode='separate',
    #     output_dir='ablation_results'
    # )

    # Example 3 (active): FTRL ablation on BiKP, hypervolume + runtime,
    # plot_mode='grouped', with Pareto data saving requested
    ftrl_example_kwargs = dict(
        problem_type='BiKP',
        problem_size='medium',
        num_runs=10,
        metrics_to_plot=['hypervolume', 'runtime'],
        plot_mode='grouped',
        save_pareto_fronts=True,
        output_dir='ablation_results_experts',
    )
    results = ablation_ftrl_toggle(**ftrl_example_kwargs)

    # Example 4 (disabled): both ablations in one call, runtime metric
    # omitted, plot_mode='separate'
    # results_1, results_2 = run_all_ablations(
    #     problem_type='TriTSP',
    #     problem_size='medium',
    #     num_runs=5,
    #     metrics_to_plot=['hypervolume', 'solutions'],  # Skip runtime
    #     plot_mode='separate',  # Individual files
    #     output_dir='ablation_results'
    # )