"""
Flexible Hyperparameter Sweep Study for MOCO Algorithms with Multi-Scale Support

This script allows you to study the effect of any hyperparameter on UCB-Hedge,
Thompson-Hedge, or both algorithms simultaneously across multiple problem scales.

Features:
- Sweep any hyperparameter with custom value ranges
- Run experiments across multiple problem scales (small, medium, large) simultaneously
- Choose which algorithm(s) to test (UCB, Thompson, or both)
- Track multiple metrics: hypervolume, runtime, solutions, tour length
- ICML-style publication-ready plots with multi-scale comparisons
- Statistical analysis and comparison
"""

import os
import yaml
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any, List, Tuple, Optional, Union
from collections import defaultdict
import warnings
import pickle
import json

warnings.filterwarnings('ignore')

import sys
from pathlib import Path

# Add the parent of this file's directory to sys.path (assumes this script sits one level below the project root)
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Import MOCO components
from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP
from MOCO.evaluation import MOCOEvaluator

# Import algorithm wrappers
from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
    CachedAdvancedBiKPWrapper as UCBWrapper
)
from our_method_dl_Thompson_variant import (
    CachedAdvancedBiKPWrapper as ThompsonWrapper
)

# ============================================================================
# ICML STYLE CONFIGURATION
# ============================================================================

def setup_icml_style():
    """Apply the ICML-paper look to matplotlib's global state.

    Activates the 'seaborn-v0_8-paper' style sheet and the seaborn "husl"
    palette, then overrides a fixed set of rcParams (fonts, figure/savefig
    resolution, layout, grid/spines, line and legend appearance).
    The function mutates ``plt.rcParams`` in place and returns nothing.
    """
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("husl")

    # Font settings
    typography = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'Liberation Sans'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
    }

    # Figure settings (on-screen and saved output)
    figure_output = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
    }

    # Layout
    layout = {
        'figure.constrained_layout.use': True,
        'figure.autolayout': False,
    }

    # Grid and spines
    grid_and_spines = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
    }

    # Lines and markers
    line_defaults = {
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
    }

    # Legend
    legend_defaults = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
        'legend.edgecolor': '0.8',
    }

    # Other
    miscellaneous = {
        'axes.axisbelow': True,
    }

    # Merge the groups in their listed order and push them to matplotlib
    # with a single update call.
    overrides = {}
    for group in (typography, figure_output, layout, grid_and_spines,
                  line_defaults, legend_defaults, miscellaneous):
        overrides.update(group)

    plt.rcParams.update(overrides)

# Apply the plotting style once at import time so every figure created by
# this module picks it up.
setup_icml_style()

# ============================================================================
# PROBLEM CONFIGURATIONS
# ============================================================================

# Registry of supported problem types, keyed by the `problem_type` string
# accepted by the sweep functions. Each entry provides:
#   'class'           - problem class handed to the evaluator as `problem_class`
#   'ref_type'        - label for the reference-point family (not read in this
#                       section of the file; presumably used elsewhere - verify)
#   'sizes'           - scale name ('small'/'medium'/'large') -> keyword
#                       arguments passed to the evaluator as `problem_params`
#   'default_ref'     - callable mapping the numeric problem size (n_cities or
#                       n_items, NOT the scale name) to a reference point tuple;
#                       any size other than the two listed falls through to the
#                       last ("large") reference point
#   'has_tour_length' - whether the 'tour_length' metric is offered for this
#                       problem type
PROBLEM_REGISTRY = {
    'BiTSP': {
        'class': BiObjectiveTSP,
        'ref_type': 'BiTSP',
        'sizes': {
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
        },
        'default_ref': lambda size: (20, 20) if size == 20 else (35, 35) if size == 50 else (65, 65),
        'has_tour_length': True
    },
    'BiKP': {
        'class': MultiObjectiveKnapsack,
        'ref_type': 'BiKP',
        'sizes': {
            'small': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5},
            'medium': {'n_items': 100, 'n_objectives': 2, 'capacity': 25.0},
            'large': {'n_items': 200, 'n_objectives': 2, 'capacity': 50.0},
        },
        'default_ref': lambda size: (5, 5) if size == 50 else (20, 20) if size == 100 else (30, 30),
        'has_tour_length': False
    },
    'TriTSP': {
        'class': TriObjectiveTSP,
        'ref_type': 'TriTSP',
        'sizes': {
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
        },
        # Three-component reference point (one entry per objective).
        'default_ref': lambda size: (20, 20, 20) if size == 20 else (35, 35, 35) if size == 50 else (65, 65, 65),
        'has_tour_length': True
    }
}

# ============================================================================
# DEFAULT HYPERPARAMETERS
# ============================================================================

def get_default_params(problem_size: int) -> Dict[str, Any]:
    """Return a fresh dictionary of baseline algorithm hyperparameters.

    Parameters
    ----------
    problem_size : int
        Accepted for interface symmetry with the callers; the returned
        defaults are currently the same for every size.

    Returns
    -------
    Dict[str, Any]
        A new dict on every call, so callers may mutate it freely.
    """
    defaults = dict(
        learning_rate=0.5,
        initial_temperature=1.0,
        temp_decay=0.98,
        hybrid_ratio=0.5,
        adaptive_hybrid=True,
        decomposition_size=35,   # alternative value noted by the author: 15
        overlap=15,              # alternative value noted by the author: 11
        max_iterations=100,      # alternative value noted by the author: 200
        nb_rounds=15,
        patience=50,
        use_lagrangian=True,
        use_ftrl=True,
        dual_step_size=1.0,
        use_accelerated_dual=True,
        use_diminishing_overlap=True,
        overlap_decay_rate=0.1,
        n_weight_vectors=25,     # alternative value noted by the author: 30
        ucb_coefficient=3.0,     # Only for UCB
    )
    return defaults

# ============================================================================
# HYPERPARAMETER SWEEP STUDY WITH MULTI-SCALE SUPPORT
# ============================================================================

def _container_has_entries(container):
    """Return True when *container* is not None and holds at least one item.

    Uses ``len`` rather than truthiness so that multi-element numpy arrays
    (whose truth value is ambiguous) are handled without raising.
    """
    if container is None:
        return False
    try:
        return len(container) > 0
    except TypeError:
        # Objects without a length fall back to ordinary truthiness.
        return bool(container)


def _mean_tour_length(result):
    """Average the objective values found on an evaluation result.

    Looks at ``result.objectives`` first and, only when that attribute is
    missing or empty, at ``result.pareto_front``. Returns ``None`` when no
    usable values are found.
    """
    objectives = getattr(result, 'objectives', None)
    if _container_has_entries(objectives):
        lengths = []
        for obj in objectives:
            if isinstance(obj, (list, tuple, np.ndarray)):
                lengths.append(np.mean(obj))
            elif isinstance(obj, (int, float, np.number)):
                lengths.append(obj)
        return np.mean(lengths) if lengths else None

    pareto_front = getattr(result, 'pareto_front', None)
    if _container_has_entries(pareto_front):
        lengths = []
        for solution in pareto_front:
            if hasattr(solution, 'objectives'):
                lengths.append(np.mean(solution.objectives))
            elif isinstance(solution, (list, tuple, np.ndarray)):
                lengths.append(np.mean(solution))
        return np.mean(lengths) if lengths else None

    return None


def hyperparameter_sweep_multiscale(
    hyperparam_name: str,
    hyperparam_values: List[Union[float, int, bool]],
    problem_type: str = 'BiTSP',
    problem_scales: List[str] = None,  # None means all scales
    algorithms: List[str] = ['UCB', 'Thompson'],
    num_runs: int = 5,
    output_dir: str = 'hyperparam_sweep_multiscale_results',
    metrics_to_plot: List[str] = ['hypervolume', 'runtime', 'solutions', 'tour_length'],
    plot_mode: str = 'grouped',
    figure_size: Tuple[float, float] = None,
    custom_base_params: Dict[str, Any] = None
):
    """
    Sweep a single hyperparameter across different values for selected algorithms
    and multiple problem scales

    For every (scale, value, algorithm) combination the algorithm is evaluated
    ``num_runs`` times (one evaluator call per run). Failed runs are reported
    with a traceback and skipped. Results are written to a YAML file in
    ``output_dir`` and, when at least one run succeeded, plotted.

    Parameters:
    -----------
    hyperparam_name : str
        Name of the hyperparameter to sweep
    hyperparam_values : List
        List of values to test for the hyperparameter
    problem_type : str
        One of 'BiTSP', 'BiKP', 'TriTSP'
    problem_scales : List[str]
        List of scales to test ['small', 'medium', 'large'] or None for all
    algorithms : List[str]
        List containing 'UCB' and/or 'Thompson'
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    metrics_to_plot : List[str]
        Metrics to plot. Options: 'hypervolume', 'runtime', 'solutions', 'tour_length'
    plot_mode : str
        'grouped' for single figure with subplots, 'separate' for individual files
    figure_size : Tuple[float, float]
        Custom figure size (width, height)
    custom_base_params : Dict[str, Any]
        Optional custom base parameters; applied on top of the defaults

    Returns:
    --------
    results : dict
        Dictionary containing all experimental results, nested as
        algorithm name -> scale -> hyperparameter value -> list of run dicts

    Raises:
    -------
    ValueError
        If no valid algorithm, scale or metric remains after filtering, or if
        ``problem_type`` is not registered.
    """
    import traceback

    print("\n" + "="*80)
    print(f"MULTI-SCALE HYPERPARAMETER SWEEP: {hyperparam_name}")
    print("="*80)

    # Validate inputs. The list defaults in the signature are never mutated:
    # every filter below builds a new list.
    valid_algorithms = ['UCB', 'Thompson']
    algorithms = [alg for alg in algorithms if alg in valid_algorithms]
    if not algorithms:
        raise ValueError(f"No valid algorithms selected. Choose from: {valid_algorithms}")

    if problem_type not in PROBLEM_REGISTRY:
        raise ValueError(f"Unknown problem type: {problem_type}")

    # Get problem configuration
    problem_config = PROBLEM_REGISTRY[problem_type]
    problem_class = problem_config['class']
    has_tour_length = problem_config['has_tour_length']

    # Determine scales to test
    valid_scales = ['small', 'medium', 'large']
    if problem_scales is None:
        problem_scales = list(valid_scales)
    else:
        problem_scales = [s for s in problem_scales if s in valid_scales]
        if not problem_scales:
            raise ValueError(f"No valid scales selected. Choose from: {valid_scales}")

    # Tell the user when tour_length is dropped. This must happen BEFORE the
    # filtering step, otherwise the metric is already gone and the warning
    # can never fire.
    if 'tour_length' in metrics_to_plot and not has_tour_length:
        print(f"\nWarning: 'tour_length' metric not available for {problem_type}. Removing from metrics.")

    # Validate metrics
    valid_metrics = ['hypervolume', 'runtime', 'solutions']
    if has_tour_length:
        valid_metrics.append('tour_length')
    metrics_to_plot = [m for m in metrics_to_plot if m in valid_metrics]
    if not metrics_to_plot:
        raise ValueError(f"No valid metrics selected. Choose from: {valid_metrics}")

    print(f"\nProblem: {problem_type}")
    print(f"Scales to test: {problem_scales}")
    print(f"Hyperparameter: {hyperparam_name}")
    print(f"Values to test: {hyperparam_values}")
    print(f"Algorithms: {algorithms}")
    print(f"Runs per configuration: {num_runs}")
    print(f"Metrics to plot: {metrics_to_plot}")
    print(f"Plot mode: {plot_mode}")

    # Storage for results: algorithm -> scale -> hyperparameter value -> runs
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    # Test each scale
    for scale in problem_scales:
        print(f"\n{'='*60}")
        print(f"TESTING SCALE: {scale.upper()}")
        print(f"{'='*60}")

        # Get problem parameters for this scale
        problem_params = problem_config['sizes'][scale]

        # Determine problem size parameter
        if 'n_cities' in problem_params:
            actual_size = problem_params['n_cities']
        elif 'n_items' in problem_params:
            actual_size = problem_params['n_items']
        else:
            actual_size = 50

        # Get reference point for this scale
        ref_point = problem_config['default_ref'](actual_size)
        print(f"Problem size: n={actual_size}, Reference point: {ref_point}")

        # Base parameters for this scale: defaults, optionally overridden
        base_params = get_default_params(actual_size)
        if custom_base_params is not None:
            base_params.update(custom_base_params)

        # Test each hyperparameter value
        for hp_value in hyperparam_values:
            print(f"\n{'-'*40}")
            print(f"Testing {hyperparam_name} = {hp_value} for {scale}")
            print(f"{'-'*40}")

            # Test each algorithm
            for algorithm in algorithms:
                print(f"\n  Testing {algorithm} on {scale}...")

                # Set parameters
                params = base_params.copy()
                params[hyperparam_name] = hp_value

                # Select appropriate wrapper
                if algorithm == 'UCB':
                    wrapper_class = UCBWrapper
                    alg_name = 'UCB-Hedge'
                else:
                    wrapper_class = ThompsonWrapper
                    alg_name = 'Thompson-Hedge'
                    # Remove ucb_coefficient for Thompson, unless it is the
                    # hyperparameter being swept.
                    if 'ucb_coefficient' in params and hyperparam_name != 'ucb_coefficient':
                        params.pop('ucb_coefficient')

                # Run experiments (one evaluator call per run so a single
                # failure does not discard the remaining runs)
                for run in range(num_runs):
                    try:
                        evaluator = MOCOEvaluator(reference_point=ref_point)
                        result = evaluator.evaluate_algorithm(
                            algorithm_class=wrapper_class,
                            problem_class=problem_class,
                            algorithm_name=f"{alg_name}_{hyperparam_name}{hp_value}_{scale}",
                            parameters=params,
                            problem_params=problem_params,
                            num_runs=1
                        )

                        # Extract tour length if available (best effort)
                        tour_length = None
                        if has_tour_length:
                            try:
                                tour_length = _mean_tour_length(result)
                            except Exception as e:
                                print(f"      Warning: Could not extract tour length: {e}")

                        run_data = {
                            'hypervolume': result.hypervolume,
                            'runtime': result.runtime,
                            'num_solutions': result.num_nondominated,
                            'tour_length': tour_length
                        }

                        # Store with scale as additional key
                        results[alg_name][scale][hp_value].append(run_data)

                        # `is not None` so a legitimate 0.0 is still reported
                        print(f"    Run {run+1}/{num_runs} - HV: {result.hypervolume:.4f}, "
                              f"Time: {result.runtime:.2f}s, Solutions: {result.num_nondominated}"
                              + (f", TourLen: {tour_length:.2f}" if tour_length is not None else ""))

                    except Exception as e:
                        # Deliberate best-effort: report the failure and
                        # continue with the next run.
                        print(f"    Run {run+1} failed: {e}")
                        traceback.print_exc()

    # Save results
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    # Convert results to serializable format (plain dicts, string hp keys)
    results_serializable = {
        alg: {
            scale: {str(hp_val): runs for hp_val, runs in hp_dict.items()}
            for scale, hp_dict in scale_dict.items()
        }
        for alg, scale_dict in results.items()
    }

    results_file = os.path.join(
        output_dir,
        f'sweep_multiscale_{hyperparam_name}_{problem_type}_{timestamp}.yaml'
    )
    # NOTE(review): to_python_types is not defined in this part of the file;
    # it is expected to be provided elsewhere in the module - confirm.
    with open(results_file, 'w') as f:
        yaml.dump(to_python_types({
            'hyperparam_name': hyperparam_name,
            'hyperparam_values': hyperparam_values,
            'problem_type': problem_type,
            'problem_scales': problem_scales,
            'algorithms': algorithms,
            'num_runs': num_runs,
            'base_params': base_params,
            'results': results_serializable
        }), f, default_flow_style=False)

    print(f"\nResults saved to: {results_file}")

    # Generate plots only when at least one run succeeded; an empty result
    # set has nothing to plot.
    if results:
        plot_hyperparam_sweep_multiscale(
            results, hyperparam_name, hyperparam_values, problem_type,
            problem_scales, output_dir, timestamp,
            metrics_to_plot, plot_mode, figure_size
        )
    else:
        print("\nWarning: no successful runs were recorded; skipping plots.")

    return dict(results)


# ============================================================================
# PLOTTING FUNCTIONS FOR MULTI-SCALE
# ============================================================================

# Canonical scale ordering plus the colour / line-style lookups shared by
# every panel drawn by plot_hyperparam_sweep_multiscale.
_SCALE_ORDER = ['small', 'medium', 'large']
_SCALE_COLORS = {
    'small': '#2ecc71',   # Green
    'medium': '#3498db',  # Blue
    'large': '#e74c3c'    # Red
}
_ALG_LINESTYLES = {
    'UCB-Hedge': '-',
    'Thompson-Hedge': '--'
}
# Raw run keys inspected for outliers; 'tour_length' may be None per run.
_RUN_METRIC_KEYS = ('hypervolume', 'runtime', 'num_solutions', 'tour_length')


def _iqr_outlier_run_indices(indexed_values, threshold=2.5):
    """Return the run indices whose value falls outside the IQR fence.

    *indexed_values* is a list of ``(run_index, value)`` pairs. Carrying the
    run index alongside the value keeps the reported indices valid even when
    some runs were skipped (e.g. a missing tour length). Fewer than four
    values are never flagged.
    """
    if len(indexed_values) < 4:
        return set()

    values = [value for _, value in indexed_values]
    q1 = np.percentile(values, 25)
    q3 = np.percentile(values, 75)
    spread = q3 - q1
    lower = q1 - threshold * spread
    upper = q3 + threshold * spread

    return {run_idx for run_idx, value in indexed_values
            if value < lower or value > upper}


def _collect_sweep_rows(results, hyperparam_name):
    """Flatten nested sweep results into (kept_rows, outlier_rows) dict lists.

    A run is an outlier when ANY of its metrics is outside the IQR fence of
    its (algorithm, scale, value) group.
    """
    kept_rows = []
    outlier_rows = []

    for alg_name, scale_dict in results.items():
        for scale, hp_dict in scale_dict.items():
            for hp_value, runs in hp_dict.items():
                outlier_indices = set()
                for key in _RUN_METRIC_KEYS:
                    indexed = [
                        (run_idx, run_result[key])
                        for run_idx, run_result in enumerate(runs)
                        if key != 'tour_length' or run_result[key] is not None
                    ]
                    outlier_indices.update(_iqr_outlier_run_indices(indexed))

                for run_idx, run_result in enumerate(runs):
                    is_outlier = run_idx in outlier_indices
                    row = {
                        'Algorithm': alg_name,
                        'Scale': scale,
                        hyperparam_name: hp_value,
                        'Hypervolume': run_result['hypervolume'],
                        'Runtime (s)': run_result['runtime'],
                        'Solutions': run_result['num_solutions'],
                        'is_outlier': is_outlier
                    }
                    if run_result['tour_length'] is not None:
                        row['Tour Length'] = run_result['tour_length']

                    (outlier_rows if is_outlier else kept_rows).append(row)

    return kept_rows, outlier_rows


def _draw_metric_panel(ax, df, df_outliers, hyperparam_name, config, show_outliers):
    """Draw median lines, IQR bands and (optionally) outlier marks for one metric."""
    column = config['column']
    scales_in_order = sorted(df['Scale'].unique(), key=_SCALE_ORDER.index)

    for alg in df['Algorithm'].unique():
        for scale in scales_in_order:
            subset = df[(df['Algorithm'] == alg) & (df['Scale'] == scale)]
            if len(subset) == 0:
                continue

            # Per-value median and quartiles of the metric
            summary = subset.groupby(hyperparam_name)[column].agg(
                median='median',
                q25=lambda x: np.percentile(x, 25),
                q75=lambda x: np.percentile(x, 75),
            ).reset_index()

            color = _SCALE_COLORS[scale]
            ax.plot(summary[hyperparam_name], summary['median'],
                    marker=config['marker'], label=f"{alg} ({scale})",
                    color=color, linestyle=_ALG_LINESTYLES.get(alg, '-'),
                    linewidth=2, markersize=6)

            # IQR band
            ax.fill_between(summary[hyperparam_name],
                            summary['q25'], summary['q75'],
                            alpha=0.15, color=color)

    # Overlay outliers. The column check protects against an outlier frame
    # that never received this metric (e.g. no outlier run had a tour length).
    if (show_outliers and df_outliers is not None and len(df_outliers) > 0
            and column in df_outliers.columns):
        outlier_subset = df_outliers[df_outliers[column].notna()]
        for scale in outlier_subset['Scale'].unique():
            scale_outliers = outlier_subset[outlier_subset['Scale'] == scale]
            ax.scatter(scale_outliers[hyperparam_name],
                       scale_outliers[column],
                       marker='x', s=60, c=_SCALE_COLORS[scale],
                       linewidths=2, alpha=0.7)

    ax.set_ylabel(config['ylabel'])
    ax.set_xlabel(hyperparam_name.replace('_', ' ').title())
    ax.grid(True, alpha=0.3, linestyle='--')


def _save_png_and_pdf(fig, output_dir, stem):
    """Save *fig* as ``<stem>.png`` (300 dpi) and ``<stem>.pdf``; return the PNG path."""
    png_path = os.path.join(output_dir, f'{stem}.png')
    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(output_dir, f'{stem}.pdf'), bbox_inches='tight')
    return png_path


def plot_hyperparam_sweep_multiscale(
    results, hyperparam_name, hyperparam_values, problem_type,
    problem_scales, output_dir, timestamp,
    metrics_to_plot, plot_mode='grouped', figure_size=None
):
    """Generate ICML-style plots for multi-scale hyperparameter sweep

    Parameters:
    -----------
    results : dict
        Nested mapping algorithm name -> scale -> hyperparameter value ->
        list of run dicts with keys 'hypervolume', 'runtime',
        'num_solutions' and 'tour_length' (the latter may be None).
    hyperparam_name : str
        Name of the swept hyperparameter (x-axis and grouping column).
    hyperparam_values, problem_scales :
        Accepted for interface compatibility; not read by this function
        (values and scales are taken from ``results``).
    problem_type : str
        Used in titles and output file names.
    output_dir : str
        Directory receiving the PNG/PDF figures and the CSV summary.
    timestamp : str
        Suffix appended to every output file name.
    metrics_to_plot : List[str]
        Subset of 'hypervolume', 'runtime', 'solutions', 'tour_length'.
    plot_mode : str
        'grouped' draws one figure with a panel per metric, in two variants
        (without and with outlier marks); any other value draws one figure
        per metric with outliers always marked.
    figure_size : Tuple[float, float]
        Custom figure size (width, height); a mode-specific default is used
        when None.

    Returns:
    --------
    None. Prints progress and a statistical summary, and writes files.
    When there is no run data, or no requested metric has data, a warning
    is printed and nothing is written.
    """

    print("\nGenerating multi-scale plots...")

    # Prepare data
    data, outliers = _collect_sweep_rows(results, hyperparam_name)
    df = pd.DataFrame(data)
    df_outliers = pd.DataFrame(outliers) if outliers else None

    if df_outliers is not None and len(df_outliers) > 0:
        print(f"⚠️  Detected {len(df_outliers)} outlier runs")

    # Without rows the frame has no columns at all, so every lookup below
    # would fail; bail out cleanly instead.
    if df.empty:
        print("\nWarning: no run data available to plot; skipping plots and summary.")
        return

    n_algorithms = df['Algorithm'].nunique()
    print(f"Data collected: {len(data)} runs across {n_algorithms} algorithm(s) and {len(df['Scale'].unique())} scale(s)")

    # Define metric configurations
    metric_config = {
        'hypervolume': {
            'column': 'Hypervolume',
            'ylabel': 'Hypervolume',
            'title': f'Hypervolume vs {hyperparam_name}',
            'marker': 'o'
        },
        'runtime': {
            'column': 'Runtime (s)',
            'ylabel': 'Runtime (seconds)',
            'title': f'Runtime vs {hyperparam_name}',
            'marker': 's'
        },
        'solutions': {
            'column': 'Solutions',
            'ylabel': 'Number of Solutions',
            'title': f'Pareto Front Size vs {hyperparam_name}',
            'marker': '^'
        },
        'tour_length': {
            'column': 'Tour Length',
            'ylabel': 'Average Tour Length',
            'title': f'Tour Length vs {hyperparam_name}',
            'marker': 'd'
        }
    }

    # Keep only requested metrics whose column exists and has data
    metric_config = {
        name: cfg for name, cfg in metric_config.items()
        if name in metrics_to_plot
        and cfg['column'] in df.columns
        and df[cfg['column']].notna().any()
    }
    n_metrics = len(metric_config)

    if n_metrics == 0:
        print("\n⚠️ Warning: No valid metrics with data to plot!")
        return

    print(f"Plotting {n_metrics} metric(s): {list(metric_config.keys())}")

    # Make the function usable on its own, not only after the sweep driver
    # has created the directory.
    os.makedirs(output_dir, exist_ok=True)

    legend_cols = 2 if n_algorithms > 1 else 1

    if plot_mode == 'grouped':
        if figure_size is None:
            figure_size = (5 * n_metrics, 4)

        # Two variants of the same figure: without and with outlier marks
        for variant in ['clean', 'with_outliers']:
            show_outliers = (variant == 'with_outliers')

            fig, axes = plt.subplots(1, n_metrics, figsize=figure_size)
            if n_metrics == 1:
                axes = [axes]

            for ax, config in zip(axes, metric_config.values()):
                _draw_metric_panel(ax, df, df_outliers, hyperparam_name,
                                   config, show_outliers)
                ax.set_title(config['title'], fontweight='bold', pad=10)
                ax.legend(frameon=True, loc='best', fontsize=7, ncol=legend_cols)

            # Overall title
            title_text = f'Multi-Scale Hyperparameter Sweep: {problem_type}\n'
            if show_outliers:
                title_text += 'Lines: median, Bands: IQR, X marks: outliers'
            else:
                title_text += 'Lines: median, Bands: IQR (25th-75th percentile)'
            fig.suptitle(title_text, fontsize=11, fontweight='bold', y=1.02)

            fig.tight_layout()

            plot_file = _save_png_and_pdf(
                fig, output_dir,
                f'sweep_multiscale_{hyperparam_name}_{problem_type}_{variant}_{timestamp}'
            )
            print(f"{variant.replace('_', ' ').capitalize()} plot saved to: {plot_file}")
            plt.close(fig)

    else:  # separate mode: one figure per metric, outliers always marked
        if figure_size is None:
            figure_size = (6, 4.5)

        for metric, config in metric_config.items():
            fig, ax = plt.subplots(1, 1, figsize=figure_size)

            _draw_metric_panel(ax, df, df_outliers, hyperparam_name,
                               config, show_outliers=True)
            ax.set_title(
                f'{config["title"]}\nMulti-Scale {problem_type}',
                fontweight='bold', pad=10
            )
            ax.legend(frameon=True, loc='best', ncol=legend_cols)

            fig.tight_layout()

            plot_file = _save_png_and_pdf(
                fig, output_dir,
                f'sweep_multiscale_{hyperparam_name}_{problem_type}_{metric}_{timestamp}'
            )
            print(f"{metric.capitalize()} plot saved to: {plot_file}")
            plt.close(fig)

    # Create statistical summary
    print("\n" + "="*80)
    print("STATISTICAL SUMMARY BY SCALE")
    print("="*80)

    # Group by algorithm, scale, and hyperparameter value
    summary_cols = ['Hypervolume', 'Runtime (s)', 'Solutions']
    if 'Tour Length' in df.columns:
        summary_cols.append('Tour Length')

    # Named function so the aggregated column is labelled 'iqr'
    def iqr(x):
        return np.percentile(x, 75) - np.percentile(x, 25)

    summary_stats = df.groupby(['Algorithm', 'Scale', hyperparam_name])[summary_cols].agg([
        'median', iqr, 'mean', 'std'
    ]).round(4)

    print("\nMedian ± IQR by Scale:")
    print(summary_stats[[(col, 'median') for col in summary_cols] +
                        [(col, 'iqr') for col in summary_cols]])

    # Save summary table
    summary_file = os.path.join(
        output_dir,
        f'sweep_multiscale_{hyperparam_name}_{problem_type}_summary_{timestamp}.csv'
    )
    summary_stats.to_csv(summary_file)
    print(f"\nSummary statistics saved to: {summary_file}")

    # Find best hyperparameter value for each algorithm and scale
    print("\n" + "="*80)
    print("BEST HYPERPARAMETER VALUES BY SCALE (by Hypervolume)")
    print("="*80)

    scales_in_order = sorted(df['Scale'].unique(), key=_SCALE_ORDER.index)
    for alg in sorted(df['Algorithm'].unique()):
        print(f"\n{alg}:")
        for scale in scales_in_order:
            scale_data = df[(df['Algorithm'] == alg) & (df['Scale'] == scale)]
            if len(scale_data) == 0:
                continue

            # Value with the highest mean hypervolume (outlier runs excluded)
            best_value = scale_data.groupby(hyperparam_name)['Hypervolume'].mean().idxmax()
            best_runs = scale_data[scale_data[hyperparam_name] == best_value]['Hypervolume']

            print(f"  {scale.capitalize()}:")
            print(f"    Best {hyperparam_name}: {best_value}")
            print(f"    Hypervolume: {best_runs.mean():.4f} ± {best_runs.std():.4f}")


# ============================================================================
# MULTI-HYPERPARAMETER SWEEP WITH MULTI-SCALE
# ============================================================================

def multi_hyperparam_sweep_multiscale(
    hyperparam_configs: List[Dict[str, Any]],
    problem_type: str = 'BiTSP',
    problem_scales: Optional[List[str]] = None,
    algorithms: List[str] = ['UCB', 'Thompson'],
    num_runs: int = 5,
    output_dir: str = 'multi_hyperparam_sweep_multiscale',
    metrics_to_plot: List[str] = ['hypervolume', 'runtime', 'solutions', 'tour_length'],
    plot_mode: str = 'grouped'
) -> Dict[str, Any]:
    """
    Run multiple hyperparameter sweeps sequentially with multi-scale support

    Every entry of ``hyperparam_configs`` is validated before the first sweep
    starts, so a malformed entry is reported immediately instead of after the
    preceding sweeps have already been executed. Each sweep is then delegated
    to ``hyperparameter_sweep_multiscale`` with the shared settings below.

    Parameters:
    -----------
    hyperparam_configs : List[Dict]
        List of configurations, each containing:
        - 'name': hyperparameter name (required)
        - 'values': list of values to test (required)
        - 'base_params': (optional) custom base parameters, forwarded as
          ``custom_base_params``; ``None`` is forwarded when absent
    problem_type : str
        One of 'BiTSP', 'BiKP', 'TriTSP'
    problem_scales : List[str]
        List of scales to test or None for all
    algorithms : List[str]
        List containing 'UCB' and/or 'Thompson'
    num_runs : int
        Number of runs per configuration
    output_dir : str
        Directory to save results
    metrics_to_plot : List[str]
        Metrics to plot
    plot_mode : str
        'grouped' or 'separate'

    Returns:
    --------
    all_results : dict
        Maps each configuration's 'name' to the value returned by its sweep.
        If two configurations share a name, the later result replaces the
        earlier one in this dictionary.

    Raises:
    -------
    KeyError
        If any configuration lacks the 'name' or 'values' key.
    """

    # Fail fast: check all configurations up front so that a typo in a later
    # entry does not surface only after the earlier sweeps have finished.
    for position, config in enumerate(hyperparam_configs, 1):
        missing = [key for key in ('name', 'values') if key not in config]
        if missing:
            raise KeyError(
                f"hyperparam_configs entry {position} is missing required "
                f"key(s): {missing}"
            )

    print("\n" + "="*80)
    print("MULTI-HYPERPARAMETER SWEEP WITH MULTI-SCALE")
    print("="*80)
    print(f"Number of hyperparameters to sweep: {len(hyperparam_configs)}")
    print(f"Problem: {problem_type}")
    print(f"Scales: {problem_scales if problem_scales else 'all'}")
    print(f"Algorithms: {algorithms}")

    all_results = {}

    for i, config in enumerate(hyperparam_configs, 1):
        print(f"\n{'='*80}")
        print(f"SWEEP {i}/{len(hyperparam_configs)}: {config['name']}")
        print(f"{'='*80}")

        # The list-valued arguments are forwarded as-is; this function never
        # mutates them.
        results = hyperparameter_sweep_multiscale(
            hyperparam_name=config['name'],
            hyperparam_values=config['values'],
            problem_type=problem_type,
            problem_scales=problem_scales,
            algorithms=algorithms,
            num_runs=num_runs,
            output_dir=output_dir,
            metrics_to_plot=metrics_to_plot,
            plot_mode=plot_mode,
            custom_base_params=config.get('base_params', None)
        )

        all_results[config['name']] = results

    print("\n" + "="*80)
    print("ALL MULTI-SCALE HYPERPARAMETER SWEEPS COMPLETED")
    print("="*80)
    print(f"Results saved to: {output_dir}/")

    return all_results


def to_python_types(obj):
    """Convert numpy types to native Python types for YAML serialization

    The conversion is recursive:
    - dict: both keys and values are converted
    - list / tuple: every element is converted; lists come back as plain
      lists, tuples (including tuple subclasses) as plain tuples
    - np.ndarray: converted with ``tolist()``
    - numpy scalars: ``np.bool_`` -> bool, ``np.integer`` -> int,
      ``np.floating`` -> float, any other ``np.generic`` via ``item()``

    Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        # Keys can be numpy scalars too (e.g. swept hyperparameter values).
        return {to_python_types(k): to_python_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_python_types(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_python_types(v) for v in obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor np.floating, so it needs its
        # own branch to become a real Python bool.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.generic):
        # Remaining numpy scalar kinds (complex, str_, bytes_, ...).
        return obj.item()
    return obj

# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":

    # ------------------------------------------------------------------
    # Recipe 1 (disabled): sweep `learning_rate` for UCB on the large scale.
    # Set problem_scales=None to cover all scales (small, medium, large),
    # and add 'Thompson' to `algorithms` to compare both methods.
    # ------------------------------------------------------------------
    # results = hyperparameter_sweep_multiscale(
    #     hyperparam_name='learning_rate',
    #     hyperparam_values=[0.1, 0.3, 0.5, 0.7, 0.9],
    #     problem_type='BiTSP',
    #     problem_scales=['large'],
    #     algorithms=['UCB'],
    #     num_runs=50,
    #     metrics_to_plot=['hypervolume', 'runtime', 'solutions', 'tour_length'],
    #     plot_mode='grouped',
    #     output_dir='hyperparam_sweep_multiscale_results'
    # )

    # ------------------------------------------------------------------
    # ACTIVE RUN: sweep `dual_step_size` for UCB on the large BiTSP scale.
    # problem_scales=None would mean all scales (small, medium, large);
    # 'Thompson' can be appended to `algorithms`.
    # ------------------------------------------------------------------
    sweep_settings = {
        'hyperparam_name': 'dual_step_size',
        'hyperparam_values': [0.1, 0.5, 1.0, 1.5, 2.0],
        'problem_type': 'BiTSP',
        'problem_scales': ['large'],
        'algorithms': ['UCB'],
        'num_runs': 50,
        'metrics_to_plot': ['hypervolume', 'runtime', 'solutions', 'tour_length'],
        'plot_mode': 'grouped',
        'output_dir': 'hyperparam_sweep_multiscale_results',
    }
    results = hyperparameter_sweep_multiscale(**sweep_settings)

    # ------------------------------------------------------------------
    # Recipe (disabled): sweep `overlap_decay_rate` for UCB on the large scale.
    # ------------------------------------------------------------------
    # results = hyperparameter_sweep_multiscale(
    #     hyperparam_name='overlap_decay_rate',
    #     hyperparam_values=[0.05, 0.1, 0.15, 0.2],
    #     problem_type='BiTSP',
    #     problem_scales=['large'],
    #     algorithms=['UCB'],
    #     num_runs=50,
    #     metrics_to_plot=['hypervolume', 'runtime', 'solutions', 'tour_length'],
    #     plot_mode='grouped',
    #     output_dir='hyperparam_sweep_multiscale_results'
    # )

    # ------------------------------------------------------------------
    # Recipe 2 (disabled): single hyperparameter, selected scales only
    # (small and large), one figure per metric via plot_mode='separate'.
    # ------------------------------------------------------------------
    # results = hyperparameter_sweep_multiscale(
    #     hyperparam_name='dual_step_size',
    #     hyperparam_values=[0.5, 1.0, 1.5, 2.0, 2.5],
    #     problem_type='BiTSP',
    #     problem_scales=['small', 'large'],
    #     algorithms=['UCB'],
    #     num_runs=10,
    #     metrics_to_plot=['hypervolume', 'runtime'],
    #     plot_mode='separate',
    #     output_dir='hyperparam_sweep_multiscale_results'
    # )

    # ------------------------------------------------------------------
    # Recipe 3 (disabled): several hyperparameters in sequence, all scales,
    # both algorithms.
    # ------------------------------------------------------------------
    # hyperparam_configs = [
    #     {
    #         'name': 'dual_step_size',
    #         'values': [0.5, 1.0, 1.5, 2.0, 2.5]
    #     },
    #     {
    #         'name': 'overlap_decay_rate',
    #         'values': [0.05, 0.1, 0.15, 0.2, 0.25]
    #     },
    #     {
    #         'name': 'learning_rate',
    #         'values': [0.1, 0.3, 0.5, 0.7, 0.9]
    #     }
    # ]
    #
    # all_results = multi_hyperparam_sweep_multiscale(
    #     hyperparam_configs=hyperparam_configs,
    #     problem_type='BiTSP',
    #     problem_scales=None,  # All scales
    #     algorithms=['UCB', 'Thompson'],
    #     num_runs=10,
    #     metrics_to_plot=['hypervolume', 'runtime', 'tour_length'],
    #     plot_mode='grouped',
    #     output_dir='multi_hyperparam_sweep_multiscale'
    # )

    # ------------------------------------------------------------------
    # Value grids used in the slides (produced with the older function
    # name shown in Recipe 4):
    #   dual_step_size:     [0.1, 0.5, 1.0, 1.5, 2.0]
    #   overlap_decay_rate: [0.05, 0.1, 0.15, 0.2]
    #   learning_rate:      [0.1, 0.3, 0.5, 0.7]
    #
    # Recipe 4 (disabled, older API): sweep `learning_rate` for Thompson
    # only on the medium problem size.
    # ------------------------------------------------------------------
    # results = hyperparameter_sweep(
    #     hyperparam_name='learning_rate',
    #     hyperparam_values=[0.1, 0.3, 0.5, 0.7, 0.9],
    #     problem_type='BiTSP',
    #     problem_size='medium',
    #     algorithms=['Thompson'],  # Thompson only
    #     num_runs=50,
    #     metrics_to_plot=['hypervolume', 'runtime', 'solutions', 'tour_length'],
    #     plot_mode='grouped',
    #     output_dir='hyperparam_sweep_results'
    # )
