import os
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from matplotlib.patches import Ellipse
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# SETUP PLOTTING STYLE
# ============================================================================

def setup_plot_style():
    """Apply the paper-style matplotlib theme used by the figures in this module.

    Activates the ``seaborn-v0_8-paper`` style sheet, then overrides font,
    resolution, grid, spine and line settings through ``plt.rcParams``.
    """
    plt.style.use('seaborn-v0_8-paper')

    # Typography: font family plus the sizes of every text element.
    font_settings = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
        'font.size': 10,
        'axes.labelsize': 11,
        'axes.titlesize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 13,
    }

    # Resolution on screen and on disk, and tight bounding boxes when saving.
    output_settings = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
    }

    # Light dashed grid, open top/right spines, and line/marker weights.
    axes_settings = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 1.0,
        'lines.linewidth': 2.0,
        'lines.markersize': 8,
    }

    plt.rcParams.update({**font_settings, **output_settings, **axes_settings})

# Applied once at import time, so importing this module changes the global
# matplotlib rcParams for the whole process.
setup_plot_style()

# Pastel color palette
# Keys have the form "<algorithm family> <FTRL status>"; get_color_map() selects
# an entry for each configuration label by substring matching.
PASTEL_COLORS = {
    'UCB With FTRL': '#A8DADC',      # Light blue
    'UCB Without FTRL': '#F4A261',    # Light orange
    'Thompson With FTRL': '#E9C46A',  # Light yellow
    'Thompson Without FTRL': '#E76F51' # Light coral
}

# ============================================================================
# DATA LOADING FUNCTIONS
# ============================================================================

def load_ftrl_yaml(yaml_file):
    """
    Load an FTRL ablation YAML file

    Parameters:
    -----------
    yaml_file : str
        Path to the YAML file to read

    Returns:
    --------
    The parsed document exactly as produced by ``yaml.safe_load``. The plotting
    functions in this module expect a mapping with a 'results' key and read the
    optional 'problem_type', 'problem_size' and 'actual_size' keys from it.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # default, which differs between systems.
    with open(yaml_file, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def yaml_to_dataframe(data, algorithms_to_plot=None):
    """
    Convert YAML data to pandas DataFrame with optional filtering

    Each entry of ``data['results']`` maps a configuration name to a list of
    run results. A configuration name containing 'UCB' is labelled
    'UCB-Hedge', anything else 'Thompson-Hedge'; a name containing
    'with_FTRL' is labelled 'With FTRL', anything else 'Without FTRL'.

    Parameters:
    -----------
    data : dict
        YAML data dictionary; must contain a 'results' mapping whose run
        entries provide 'hypervolume', 'runtime' and 'num_solutions'
    algorithms_to_plot : list, optional
        List of algorithm configurations to include. Options:
        - 'UCB-Hedge With FTRL'
        - 'UCB-Hedge Without FTRL'
        - 'Thompson-Hedge With FTRL'
        - 'Thompson-Hedge Without FTRL'
        ``None`` keeps every configuration.

    Returns:
    --------
    pandas.DataFrame with one row per run and the columns 'Algorithm', 'FTRL',
    'Config', 'Hypervolume', 'Runtime (s)' and 'Solutions'. The columns are
    present even when no run survives the filter.
    """
    columns = ['Algorithm', 'FTRL', 'Config', 'Hypervolume', 'Runtime (s)', 'Solutions']
    rows = []

    for config_name, runs in data['results'].items():
        # Parse config name
        algorithm = 'UCB-Hedge' if 'UCB' in config_name else 'Thompson-Hedge'
        ftrl_status = 'With FTRL' if 'with_FTRL' in config_name else 'Without FTRL'
        full_config = f"{algorithm} {ftrl_status}"

        # Check if we should include this config
        if algorithms_to_plot is not None and full_config not in algorithms_to_plot:
            continue

        # A configuration listed without any runs parses as None in YAML;
        # treat it as an empty run list.
        rows.extend(
            {
                'Algorithm': algorithm,
                'FTRL': ftrl_status,
                'Config': full_config,
                'Hypervolume': run_result['hypervolume'],
                'Runtime (s)': run_result['runtime'],
                'Solutions': run_result['num_solutions']
            }
            for run_result in (runs or [])
        )

    # Passing the column list keeps the schema stable when `rows` is empty, so
    # callers can still index df['Config'] instead of hitting a KeyError.
    return pd.DataFrame(rows, columns=columns)


def get_color_map(df):
    """Return a ``{config label: hex color}`` dict for the configs present in ``df``.

    Each distinct value of ``df['Config']`` is matched against the algorithm
    family ('UCB' / 'Thompson') and FTRL status ('With FTRL' / 'Without FTRL')
    substrings and receives the corresponding ``PASTEL_COLORS`` entry. Labels
    that match no combination are left out of the result.
    """
    # Listed in the order the combinations are tried; the first hit wins.
    candidates = [
        (family, status)
        for family in ('UCB', 'Thompson')
        for status in ('With FTRL', 'Without FTRL')
    ]

    palette_lookup = {}
    for label in df['Config'].unique():
        for family, status in candidates:
            if family in label and status in label:
                palette_lookup[label] = PASTEL_COLORS[f'{family} {status}']
                break

    return palette_lookup


# ============================================================================
# PLOT 1: SIMPLE BAR CHARTS
# ============================================================================

def _draw_metric_bars(ax, summary_stats, metric, color_map, value_format,
                      show_error_bars, show_values, fontsize):
    """Draw one bar per configuration for ``metric`` on ``ax``.

    Shared by the 'grouped' and 'separate' modes of plot_simple_bar_charts.
    ``summary_stats`` must hold a 'Config' column plus '<metric>_mean' and
    '<metric>_std' columns. ``fontsize`` is applied to both the value labels
    and the x tick labels; the caller sets the axis label and title.
    """
    configs = summary_stats['Config'].values
    means = summary_stats[f'{metric}_mean'].values
    stds = summary_stats[f'{metric}_std'].values if show_error_bars else None

    x_pos = np.arange(len(configs))
    bars = ax.bar(
        x_pos,
        means,
        yerr=stds,
        color=[color_map[config] for config in configs],
        edgecolor='black',
        linewidth=1.5,
        capsize=5 if show_error_bars else 0,
        error_kw={'linewidth': 1.5, 'ecolor': 'black', 'capthick': 1.5},
        alpha=0.9,
        width=0.7
    )

    if show_values:
        if show_error_bars:
            # The sample std of a single-run configuration is NaN; adding it
            # would give the label a NaN position, so fall back to zero.
            std_values = np.asarray(stds, dtype=float)
            label_offsets = np.where(np.isnan(std_values), 0.0, std_values)
        else:
            label_offsets = np.zeros(len(means))

        for bar, val, offset in zip(bars, means, label_offsets):
            height = bar.get_height()
            # Sit on top of the error bar, or 3% above the bar without one.
            label_y = height + (offset if show_error_bars else height * 0.03)
            ax.text(
                bar.get_x() + bar.get_width()/2.,
                label_y,
                f'{val:{value_format}}',
                ha='center',
                va='bottom',
                fontsize=fontsize,
                fontweight='bold'
            )

    ax.set_xticks(x_pos)
    # One word per line keeps the long configuration names readable.
    ax.set_xticklabels([c.replace(' ', '\n') for c in configs], fontsize=fontsize)
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')
    ax.set_axisbelow(True)


def _save_bar_figure(fig, plot_file, dpi):
    """Save ``fig`` as PNG (at ``dpi``) and as a PDF beside it, then close it."""
    fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
    # Swap only the extension; str.replace would also rewrite a '.png'
    # occurring elsewhere in the path.
    fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
    print(f"✓ Saved: {plot_file}")
    plt.close(fig)


def plot_simple_bar_charts(
    yaml_file,
    output_dir='plots',
    algorithms_to_plot=None,
    metrics_to_plot=None,
    plot_mode='grouped',
    show_error_bars=True,
    show_values=True,
    figsize=None,
    dpi=300
):
    """
    Create simple, clean bar charts of the per-configuration mean of each metric

    Every figure is written to ``output_dir`` as a PNG and a PDF.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include:
        - 'UCB-Hedge With FTRL'
        - 'UCB-Hedge Without FTRL'
        - 'Thompson-Hedge With FTRL'
        - 'Thompson-Hedge Without FTRL'
    metrics_to_plot : list, optional
        Metrics to plot: ['Hypervolume', 'Runtime (s)', 'Solutions']
        (all three when omitted)
    plot_mode : str
        'grouped' draws all metrics side by side in one figure; any other
        value ('separate') writes one figure per metric
    show_error_bars : bool
        Show standard deviation error bars
    show_values : bool
        Show values on bars
    figsize : tuple, optional
        Figure size; defaults to (5 * number of metrics, 5) in grouped mode
        and (8, 5) in separate mode
    dpi : int
        Resolution of the PNG output
    """

    print("\n" + "="*80)
    print("CREATING SIMPLE BAR CHARTS")
    print("="*80)

    # Load and prepare data
    data = load_ftrl_yaml(yaml_file)
    df = yaml_to_dataframe(data, algorithms_to_plot)

    problem_type = data.get('problem_type', 'Unknown')
    problem_size = data.get('problem_size', 'Unknown')
    actual_size = data.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")
    print(f"Configurations: {df['Config'].unique()}")
    print(f"Total runs: {len(df)}")

    os.makedirs(output_dir, exist_ok=True)

    # Set default metrics
    if metrics_to_plot is None:
        metrics_to_plot = ['Hypervolume', 'Runtime (s)', 'Solutions']

    # Compute summary statistics
    summary_stats = df.groupby('Config').agg({
        'Hypervolume': ['mean', 'std'],
        'Runtime (s)': ['mean', 'std'],
        'Solutions': ['mean', 'std']
    }).reset_index()

    # Flatten the (metric, statistic) column index into '<metric>_<statistic>'.
    summary_stats.columns = ['Config'] + [f'{col[0]}_{col[1]}' for col in summary_stats.columns[1:]]

    # Get colors
    color_map = get_color_map(df)

    metric_info = {
        'Hypervolume': {
            'label': 'Hypervolume',
            'format': '.4f'
        },
        'Runtime (s)': {
            'label': 'Runtime (seconds)',
            'format': '.2f'
        },
        'Solutions': {
            'label': 'Number of Solutions',
            'format': '.1f'
        }
    }

    if plot_mode == 'grouped':
        n_metrics = len(metrics_to_plot)
        if figsize is None:
            figsize = (5 * n_metrics, 5)

        # squeeze=False always yields a 2-D array, so one metric needs no special case.
        fig, axes = plt.subplots(1, n_metrics, figsize=figsize, squeeze=False)

        for ax, metric in zip(axes[0], metrics_to_plot):
            info = metric_info[metric]
            _draw_metric_bars(
                ax, summary_stats, metric, color_map, info['format'],
                show_error_bars, show_values, fontsize=9
            )
            ax.set_ylabel(info['label'], fontsize=11, fontweight='bold')
            ax.set_title(info['label'], fontsize=12, fontweight='bold', pad=10)

        fig.suptitle(
            f'FTRL Ablation: {problem_type}{actual_size}',
            fontsize=14,
            fontweight='bold',
            y=0.98
        )

        fig.tight_layout()

        plot_file = os.path.join(output_dir, f'bar_chart_grouped_{problem_type}_{problem_size}.png')
        _save_bar_figure(fig, plot_file, dpi)

    else:  # separate mode
        if figsize is None:
            figsize = (8, 5)

        for metric in metrics_to_plot:
            info = metric_info[metric]
            fig, ax = plt.subplots(1, 1, figsize=figsize)

            _draw_metric_bars(
                ax, summary_stats, metric, color_map, info['format'],
                show_error_bars, show_values, fontsize=10
            )
            ax.set_ylabel(info['label'], fontsize=12, fontweight='bold')
            ax.set_title(
                f'{info["label"]}: {problem_type}{actual_size}',
                fontsize=13,
                fontweight='bold',
                pad=15
            )

            fig.tight_layout()

            plot_file = os.path.join(
                output_dir,
                f'bar_chart_{metric.replace(" ", "_").replace("(", "").replace(")", "")}_{problem_type}_{problem_size}.png'
            )
            _save_bar_figure(fig, plot_file, dpi)

    print("\n✓ Bar charts completed!")


# ============================================================================
# PLOT 2: BOX PLOTS WITH SIGNIFICANCE
# ============================================================================

def _significance_marker(p_value):
    """Map a p-value to the star notation used in the box plot annotations."""
    if p_value < 0.001:
        return '***'
    if p_value < 0.01:
        return '**'
    if p_value < 0.05:
        return '*'
    return 'n.s.'


def plot_boxplots_with_significance(
    yaml_file,
    output_dir='plots',
    algorithms_to_plot=None,
    metrics_to_plot=None,
    figsize=None,
    dpi=300
):
    """
    Create box plots with statistical significance tests

    One panel per metric shows the 'With FTRL' and 'Without FTRL' runs of each
    algorithm as box plus strip plots. For every algorithm that has runs in
    both groups, a two-sided Mann-Whitney U test compares them and the result
    is drawn as a bracket above the boxes. The figure is written to
    ``output_dir`` as a PNG and a PDF.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include
    metrics_to_plot : list, optional
        Metrics to plot (all three when omitted)
    figsize : tuple, optional
        Figure size; defaults to (5 * number of metrics, 5)
    dpi : int
        Resolution of the PNG output
    """

    print("\n" + "="*80)
    print("CREATING BOX PLOTS WITH SIGNIFICANCE TESTS")
    print("="*80)

    # Load data
    data = load_ftrl_yaml(yaml_file)
    df = yaml_to_dataframe(data, algorithms_to_plot)

    problem_type = data.get('problem_type', 'Unknown')
    problem_size = data.get('problem_size', 'Unknown')
    actual_size = data.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    # Set default metrics
    if metrics_to_plot is None:
        metrics_to_plot = ['Hypervolume', 'Runtime (s)', 'Solutions']

    metric_info = {
        'Hypervolume': 'Hypervolume',
        'Runtime (s)': 'Runtime (seconds)',
        'Solutions': 'Number of Solutions'
    }

    # Set figure size
    if figsize is None:
        figsize = (5 * len(metrics_to_plot), 5)

    fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=figsize)
    if len(metrics_to_plot) == 1:
        axes = [axes]

    # The significance brackets are placed at x = 0, 1, ... in this order, so
    # the same order is passed to seaborn explicitly rather than assumed.
    algo_order = list(df['Algorithm'].unique())

    for idx, metric in enumerate(metrics_to_plot):
        ax = axes[idx]

        # Create box plot
        sns.boxplot(
            data=df,
            x='Algorithm',
            y=metric,
            hue='FTRL',
            order=algo_order,
            ax=ax,
            palette={
                'With FTRL': '#A8DADC',
                'Without FTRL': '#F4A261'
            },
            linewidth=1.5,
            width=0.6
        )

        # Overlay the individual runs in darker shades of the box colors
        sns.stripplot(
            data=df,
            x='Algorithm',
            y=metric,
            hue='FTRL',
            order=algo_order,
            ax=ax,
            dodge=True,
            alpha=0.4,
            size=4,
            palette={
                'With FTRL': '#457B9D',
                'Without FTRL': '#E76F51'
            },
            legend=False
        )

        # Statistical significance tests
        y_max = df[metric].max()
        y_min = df[metric].min()
        y_range = y_max - y_min

        for i, algo in enumerate(algo_order):
            with_ftrl = df[(df['Algorithm'] == algo) & (df['FTRL'] == 'With FTRL')][metric]
            without_ftrl = df[(df['Algorithm'] == algo) & (df['FTRL'] == 'Without FTRL')][metric]

            if len(with_ftrl) == 0 or len(without_ftrl) == 0:
                continue

            # Mann-Whitney U test
            try:
                _, p_value = stats.mannwhitneyu(with_ftrl, without_ftrl, alternative='two-sided')
            except ValueError:
                # Older SciPy releases raise when every observation is
                # identical; such samples cannot differ significantly.
                p_value = 1.0

            sig_text = _significance_marker(p_value)

            # Draw significance bar; each algorithm's bracket is stacked a
            # little higher than the previous one.
            x1 = i - 0.2
            x2 = i + 0.2
            h = y_max + 0.05 * y_range + i * 0.08 * y_range

            ax.plot([x1, x1, x2, x2], [h, h + 0.02*y_range, h + 0.02*y_range, h],
                    'k-', linewidth=1.5)
            ax.text((x1 + x2)/2, h + 0.02*y_range, sig_text,
                    ha='center', va='bottom', fontsize=10, fontweight='bold')

        # Styling
        ax.set_ylabel(metric_info[metric], fontsize=11, fontweight='bold')
        ax.set_xlabel('')
        ax.set_title(metric_info[metric], fontsize=12, fontweight='bold', pad=10)
        ax.legend(title='FTRL Status', frameon=True, loc='best', fontsize=9)
        ax.grid(True, alpha=0.3, linestyle='--', axis='y')
        ax.set_axisbelow(True)

    fig.suptitle(
        f'Statistical Comparison: {problem_type}{actual_size}\n(*, **, *** = p<0.05, 0.01, 0.001; n.s. = not significant)',
        fontsize=13,
        fontweight='bold',
        y=1.02
    )

    fig.tight_layout()

    plot_file = os.path.join(output_dir, f'boxplot_significance_{problem_type}_{problem_size}.png')
    fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
    # Swap only the extension; str.replace would also rewrite a '.png'
    # occurring elsewhere in the path.
    fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
    print(f"✓ Saved: {plot_file}")
    plt.close(fig)

    print("\n✓ Box plots completed!")


# ============================================================================
# PLOT 3: VIOLIN PLOTS
# ============================================================================

def plot_violin_distributions(
    yaml_file,
    output_dir='plots',
    algorithms_to_plot=None,
    metrics_to_plot=None,
    figsize=None,
    dpi=300
):
    """
    Create violin plots showing the distribution of each metric

    One panel per metric, with the algorithm on the x axis and the FTRL status
    as hue. The figure is written to ``output_dir`` as a PNG and a PDF.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include
    metrics_to_plot : list, optional
        Metrics to plot (all three when omitted)
    figsize : tuple, optional
        Figure size; defaults to (5 * number of metrics, 5)
    dpi : int
        Resolution of the PNG output
    """
    banner = "=" * 80
    print("\n" + banner)
    print("CREATING VIOLIN PLOTS")
    print(banner)

    # Load the results and pull out the labels used in titles and file names
    results = load_ftrl_yaml(yaml_file)
    frame = yaml_to_dataframe(results, algorithms_to_plot)

    problem_type = results.get('problem_type', 'Unknown')
    problem_size = results.get('problem_size', 'Unknown')
    actual_size = results.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    metrics = (
        ['Hypervolume', 'Runtime (s)', 'Solutions']
        if metrics_to_plot is None
        else metrics_to_plot
    )

    axis_titles = {
        'Hypervolume': 'Hypervolume',
        'Runtime (s)': 'Runtime (seconds)',
        'Solutions': 'Number of Solutions'
    }

    n_panels = len(metrics)
    canvas_size = (5 * n_panels, 5) if figsize is None else figsize

    fig, axes = plt.subplots(1, n_panels, figsize=canvas_size)
    # A single panel comes back as a bare Axes; wrap it so it can be iterated.
    panels = [axes] if n_panels == 1 else axes

    ftrl_palette = {
        'With FTRL': '#A8DADC',
        'Without FTRL': '#F4A261'
    }

    for panel, metric in zip(panels, metrics):
        sns.violinplot(
            data=frame,
            x='Algorithm',
            y=metric,
            hue='FTRL',
            ax=panel,
            palette=ftrl_palette,
            split=False,
            inner='quartile',
            linewidth=1.5,
            alpha=0.8
        )

        title = axis_titles[metric]
        panel.set_ylabel(title, fontsize=11, fontweight='bold')
        panel.set_xlabel('')
        panel.set_title(title, fontsize=12, fontweight='bold', pad=10)
        panel.legend(title='FTRL Status', frameon=True, loc='best', fontsize=9)
        panel.grid(True, alpha=0.3, linestyle='--', axis='y')
        panel.set_axisbelow(True)

    fig.suptitle(
        f'Distribution Analysis: {problem_type}{actual_size}',
        fontsize=13,
        fontweight='bold',
        y=0.98
    )

    plt.tight_layout()

    # Save the PNG first, then a PDF with the same base name
    png_path = os.path.join(output_dir, f'violin_distributions_{problem_type}_{problem_size}.png')
    plt.savefig(png_path, dpi=dpi, bbox_inches='tight')
    plt.savefig(png_path.replace('.png', '.pdf'), bbox_inches='tight')
    print(f"✓ Saved: {png_path}")
    plt.close()

    print("\n✓ Violin plots completed!")


# ============================================================================
# PLOT 4: IMPROVEMENT PERCENTAGE
# ============================================================================

def plot_improvement_percentage(
    yaml_file,
    output_dir='plots',
    algorithms_to_plot=None,
    figsize=(10, 6),
    dpi=300
):
    """
    Create improvement percentage bar chart

    For every algorithm that has both 'With FTRL' and 'Without FTRL' runs, the
    relative change of the mean of each metric is computed against the
    'Without FTRL' mean. Runtime is sign-flipped so that a positive value
    always means FTRL is better. The chart is written to ``output_dir`` as a
    PNG and a PDF, and a text summary is printed.

    If no algorithm has both variants, a notice is printed and nothing is
    plotted.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include (must include both With/Without FTRL pairs)
    figsize : tuple
        Figure size
    dpi : int
        Resolution of the PNG output
    """

    print("\n" + "="*80)
    print("CREATING IMPROVEMENT PERCENTAGE CHART")
    print("="*80)

    # Load data
    data = load_ftrl_yaml(yaml_file)
    df = yaml_to_dataframe(data, algorithms_to_plot)

    problem_type = data.get('problem_type', 'Unknown')
    problem_size = data.get('problem_size', 'Unknown')
    actual_size = data.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    # Calculate improvements
    improvements = []

    # An empty frame (nothing survived the filter) has no algorithms to compare.
    algorithms = [] if df.empty else df['Algorithm'].unique()

    for algo in algorithms:
        with_ftrl = df[(df['Algorithm'] == algo) & (df['FTRL'] == 'With FTRL')]
        without_ftrl = df[(df['Algorithm'] == algo) & (df['FTRL'] == 'Without FTRL')]

        if len(with_ftrl) == 0 or len(without_ftrl) == 0:
            continue

        for metric in ['Hypervolume', 'Runtime (s)', 'Solutions']:
            mean_with = with_ftrl[metric].mean()
            mean_without = without_ftrl[metric].mean()

            if mean_without == 0:
                # A relative change against a zero baseline is undefined;
                # NaN avoids an infinite bar.
                improvement = np.nan
            elif metric == 'Runtime (s)':
                # For runtime, lower is better (so we reverse the calculation)
                improvement = (mean_without - mean_with) / mean_without * 100
            else:
                improvement = (mean_with - mean_without) / mean_without * 100

            improvements.append({
                'Algorithm': algo,
                'Metric': metric.replace(' (s)', ''),
                'Improvement (%)': improvement
            })

    if not improvements:
        # Without at least one With/Without pair there is nothing to pivot or plot.
        print("⚠ No algorithm has both 'With FTRL' and 'Without FTRL' runs; "
              "skipping improvement chart.")
        return

    imp_df = pd.DataFrame(improvements)

    # Pivot for grouped bar chart
    pivot_df = imp_df.pivot(index='Metric', columns='Algorithm', values='Improvement (%)')

    # Assign colors by algorithm name rather than by column position, so each
    # algorithm keeps the same color whichever algorithms are present.
    algo_colors = {'UCB-Hedge': '#A8DADC', 'Thompson-Hedge': '#E9C46A'}
    bar_colors = [algo_colors.get(algo, '#CCCCCC') for algo in pivot_df.columns]

    # Create the figure only once the data is known to be plottable
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Plot
    pivot_df.plot(
        kind='bar',
        ax=ax,
        color=bar_colors,
        width=0.7,
        edgecolor='black',
        linewidth=1.5,
        alpha=0.9
    )

    # Add zero line
    ax.axhline(0, color='black', linewidth=1.5, linestyle='-', zorder=1)

    # Styling
    ax.set_ylabel('Improvement with FTRL (%)', fontsize=12, fontweight='bold')
    ax.set_xlabel('Metric', fontsize=12, fontweight='bold')
    ax.set_title(
        f'FTRL Impact: {problem_type}{actual_size}\n(Positive = FTRL improves performance)',
        fontsize=13,
        fontweight='bold',
        pad=15
    )
    ax.legend(title='Algorithm', frameon=True, fontsize=10, loc='best')
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')
    ax.set_axisbelow(True)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)

    # Add value labels
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f%%', fontsize=9, fontweight='bold')

    fig.tight_layout()

    plot_file = os.path.join(output_dir, f'improvement_percentage_{problem_type}_{problem_size}.png')
    fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
    # Swap only the extension; str.replace would also rewrite a '.png'
    # occurring elsewhere in the path.
    fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
    print(f"✓ Saved: {plot_file}")
    plt.close(fig)

    print("\n✓ Improvement chart completed!")

    # Print improvement summary
    print("\nImprovement Summary:")
    print("-" * 60)
    for _, row in imp_df.iterrows():
        print(f"{row['Algorithm']:20s} {row['Metric']:15s}: {row['Improvement (%)']:+7.2f}%")


# ============================================================================
# PLOT 5: PARETO SCATTER
# ============================================================================

def plot_pareto_scatter(
    yaml_file,
    output_dir='plots',
    algorithms_to_plot=None,
    figsize=(8, 6),
    dpi=300
):
    """
    Create Pareto-style scatter plot (Hypervolume vs Runtime)

    Every run is drawn as a point (runtime on x, hypervolume on y) using the
    color and marker of its configuration; a large star marks each
    configuration's mean. The figure is written to ``output_dir`` as a PNG and
    a PDF.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include
    figsize : tuple
        Figure size
    dpi : int
        Resolution of the PNG output
    """
    banner = "=" * 80
    print("\n" + banner)
    print("CREATING PARETO SCATTER PLOT")
    print(banner)

    # Load the results and pull out the labels used in the title and file name
    results = load_ftrl_yaml(yaml_file)
    frame = yaml_to_dataframe(results, algorithms_to_plot)

    problem_type = results.get('problem_type', 'Unknown')
    problem_size = results.get('problem_size', 'Unknown')
    actual_size = results.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    colors_by_config = get_color_map(frame)

    # One marker shape per configuration; unknown labels fall back to a circle.
    markers_by_config = {
        'UCB-Hedge With FTRL': 'o',
        'UCB-Hedge Without FTRL': 's',
        'Thompson-Hedge With FTRL': '^',
        'Thompson-Hedge Without FTRL': 'v'
    }

    for label in frame['Config'].unique():
        runs = frame[frame['Config'] == label]
        runtimes = runs['Runtime (s)']
        hypervolumes = runs['Hypervolume']
        shade = colors_by_config[label]

        # Individual runs
        ax.scatter(
            runtimes,
            hypervolumes,
            c=shade,
            marker=markers_by_config.get(label, 'o'),
            s=80,
            alpha=0.6,
            edgecolors='black',
            linewidth=1.0,
            label=label
        )

        # Configuration mean, drawn on top of the individual runs
        ax.scatter(
            runtimes.mean(),
            hypervolumes.mean(),
            c=shade,
            marker='*',
            s=400,
            edgecolors='black',
            linewidth=2.0,
            zorder=10
        )

    ax.set_xlabel('Runtime (seconds)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Hypervolume', fontsize=12, fontweight='bold')
    ax.set_title(
        f'Performance Trade-off: {problem_type}{actual_size}\n(Stars = mean values)',
        fontsize=13,
        fontweight='bold',
        pad=15
    )
    ax.legend(frameon=True, loc='best', fontsize=9)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    plt.tight_layout()

    # Save the PNG first, then a PDF with the same base name
    png_path = os.path.join(output_dir, f'pareto_scatter_{problem_type}_{problem_size}.png')
    plt.savefig(png_path, dpi=dpi, bbox_inches='tight')
    plt.savefig(png_path.replace('.png', '.pdf'), bbox_inches='tight')
    print(f"✓ Saved: {png_path}")
    plt.close()

    print("\n✓ Pareto scatter completed!")


# ============================================================================
# CONVENIENCE FUNCTION: GENERATE ALL PLOTS
# ============================================================================

def generate_all_plots(
    yaml_file,
    output_dir='all_plots',
    algorithms_to_plot=None,
    dpi=300
):
    """
    Generate all plot types in one go

    Runs, in order: grouped bar charts, box plots with significance tests,
    violin plots, the improvement percentage chart and the Pareto scatter,
    all with the same input file, output directory, filter and resolution.

    Parameters:
    -----------
    yaml_file : str
        Path to YAML file
    output_dir : str
        Directory to save all plots (created if missing)
    algorithms_to_plot : list, optional
        List of algorithms to include
    dpi : int
        Resolution
    """
    banner = "=" * 80

    print("\n" + banner)
    print("GENERATING ALL FTRL ABLATION PLOTS")
    print(banner)

    os.makedirs(output_dir, exist_ok=True)

    # Arguments every plotting function receives unchanged.
    shared_kwargs = {
        'yaml_file': yaml_file,
        'output_dir': output_dir,
        'algorithms_to_plot': algorithms_to_plot,
        'dpi': dpi,
    }

    # 1. Simple bar charts (the only plotter that takes an extra option)
    plot_simple_bar_charts(plot_mode='grouped', **shared_kwargs)

    # 2-5. Box plots, violin plots, improvement percentage, Pareto scatter
    remaining_plotters = (
        plot_boxplots_with_significance,
        plot_violin_distributions,
        plot_improvement_percentage,
        plot_pareto_scatter,
    )
    for plotter in remaining_plotters:
        plotter(**shared_kwargs)

    print("\n" + banner)
    print("✓ ALL PLOTS COMPLETED SUCCESSFULLY!")
    print(banner)
    print(f"All plots saved to: {output_dir}/")


# ============================================================================
# USAGE EXAMPLES
# ============================================================================

if __name__ == "__main__":

    # Results file to visualise; the path is relative to the current working directory.
    yaml_file = 'ablation_results_experts/ablation2_ftrl_BiTSP_large_20251114-051433.yaml'

    # Example 1: Generate ALL plot types.
    # NOTE: despite the banner printed below, this run is currently restricted
    # to the two UCB-Hedge configurations (the Thompson-Hedge entries are
    # commented out). Use algorithms_to_plot=None to include every
    # configuration found in the file.
    print("\n" + "="*80)
    print("EXAMPLE 1: All algorithms, all plots")
    print("="*80)
    generate_all_plots(
        yaml_file=yaml_file,
        output_dir='ftrl_plots_all_UCB',
        # algorithms_to_plot=None,  # All algorithms
        algorithms_to_plot=[
            'UCB-Hedge With FTRL',
            'UCB-Hedge Without FTRL',
            # 'Thompson-Hedge With FTRL',
            # 'Thompson-Hedge Without FTRL'
        ],
        dpi=300
    )

    # The remaining examples are disabled; uncomment one to run it.

    # Example 2: Only UCB algorithms
    # print("\n" + "="*80)
    # print("EXAMPLE 2: UCB algorithms only")
    # print("="*80)
    # generate_all_plots(
    #     yaml_file=yaml_file,
    #     output_dir='ftrl_plots_ucb',
    #     algorithms_to_plot=[
    #         'UCB-Hedge With FTRL',
    #         'UCB-Hedge Without FTRL'
    #     ],
    #     dpi=300
    # )

    # Example 3: Only Thompson algorithms
    # print("\n" + "="*80)
    # print("EXAMPLE 3: Thompson algorithms only")
    # print("="*80)
    # generate_all_plots(
    #     yaml_file=yaml_file,
    #     output_dir='ftrl_plots_thompson',
    #     algorithms_to_plot=[
    #         'Thompson-Hedge With FTRL',
    #         'Thompson-Hedge Without FTRL'
    #     ],
    #     dpi=300
    # )

    # Example 4: Individual plot types with custom settings
    # print("\n" + "="*80)
    # print("EXAMPLE 4: Custom individual plots")
    # print("="*80)

    # # Just bar charts, one file per metric
    # plot_simple_bar_charts(
    #     yaml_file=yaml_file,
    #     output_dir='ftrl_custom_bars',
    #     algorithms_to_plot=None,
    #     metrics_to_plot=['Hypervolume', 'Runtime (s)'],
    #     plot_mode='separate',
    #     show_error_bars=True,
    #     show_values=True,
    #     dpi=300
    # )

    # Just improvement chart
    # plot_improvement_percentage(
    #     yaml_file=yaml_file,
    #     output_dir='ftrl_custom_improvement',
    #     algorithms_to_plot=None,
    #     dpi=300
    # )