"""
FTRL Paper Plotting Script (Post-hoc)

Reads YAML results from FTRL ablation study and generates publication-ready figures.

Layout:
- Figure 1 (Rate Study): 2 rows × 3 cols - Mean±Std, HV Distribution, Worst-case
- Figure 2 (Variance): 2 rows × 4 cols - Box, Violin, CV, Variance Reduction %
- Figure 3 (Stress Tests): 2 rows × 3 cols - Mean Heatmap, Std Heatmap, Worst-case
- Figure 4 (Regret): 2 rows × 3 cols - Cumulative Regret, Convergence, Stability

Author: Research Team
Date: 2024
"""

import os
import yaml
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any, List, Tuple, Optional
from pathlib import Path
import warnings

# Globally silence every warning category for the whole process.  This keeps
# the console output of the plotting run clean, but it also hides any
# deprecation or runtime warnings emitted by the libraries used below.
warnings.filterwarnings('ignore')


# ============================================================================
# ICML STYLE CONFIGURATION
# ============================================================================

def setup_icml_style():
    """Configure matplotlib and seaborn defaults for ICML-style figures.

    Applies the "seaborn paper" matplotlib style, installs a six-colour
    colorblind-friendly palette as the seaborn default, and overrides the
    global ``rcParams`` (fonts, DPI, grid, spines, line and marker sizes,
    legend frame).  The function mutates global plotting state and returns
    nothing.

    Raises:
    -------
    OSError
        If neither spelling of the seaborn paper style is known to the
        installed matplotlib.
    """
    # Matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8-*``;
    # older releases only know the un-versioned name, so fall back to it.
    try:
        plt.style.use('seaborn-v0_8-paper')
    except OSError:
        plt.style.use('seaborn-paper')

    # Colorblind-friendly palette
    colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161', '#949494']
    sns.set_palette(sns.color_palette(colors))

    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
        'font.size': 12,           # Base font size
        'axes.labelsize': 14,      # Axis labels
        'axes.titlesize': 15,      # Subplot titles
        'xtick.labelsize': 11,     # X tick labels
        'ytick.labelsize': 11,     # Y tick labels
        'legend.fontsize': 11,     # Legend text
        'figure.titlesize': 16,    # Figure suptitle
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'lines.linewidth': 2.0,
        'lines.markersize': 8,
        'legend.frameon': True,
        'legend.framealpha': 0.9,
    })


# ============================================================================
# YAML LOADING UTILITIES
# ============================================================================

class YAMLLoader:
    """
    Utility class to find and load YAML result files.

    Result files are looked up inside a single directory by file name; the
    name is expected to start with ``<study prefix>_<problem type>`` and may
    carry a size tag (``small``/``medium``/``large``) and a free-form suffix.

    Parameters:
    -----------
    yaml_dir : str
        Directory that holds the ``*.yaml`` result files

    Raises:
    -------
    ValueError
        If ``yaml_dir`` does not exist
    """

    # Map problem size names to actual n values
    SIZE_TO_N = {
        'BiTSP': {'small': 20, 'medium': 50, 'large': 100},
        'BiKP': {'small': 50, 'medium': 100, 'large': 200},
        'TriTSP': {'small': 20, 'medium': 50, 'large': 100}
    }

    # Map study names to the file-name prefix used by their result files
    STUDY_PREFIXES = {
        'rate': 'ftrl_rate_study',
        'variance': 'variance_analysis',
        'stress': 'stress_tests',
        'regret': 'regret_analysis'
    }

    def __init__(self, yaml_dir: str):
        self.yaml_dir = Path(yaml_dir)
        if not self.yaml_dir.exists():
            raise ValueError(f"YAML directory does not exist: {yaml_dir}")

    @staticmethod
    def _newest(paths) -> Optional[Path]:
        """Return the most recently modified path, or None when there is none."""
        return max(paths, key=lambda p: p.stat().st_mtime, default=None)

    def find_yaml(self, study: str, problem_type: str, problem_size: str) -> Optional[Path]:
        """
        Find YAML file matching study, problem type, and size.

        Files that name the requested size are preferred.  Only if none
        exists is a file without a size tag accepted; files tagged with a
        *different* known size are never returned, so a missing 'small' run
        cannot silently be replaced by the 'large' one.

        Parameters:
        -----------
        study : str
            One of 'rate', 'variance', 'stress', 'regret'; any other value
            is used verbatim as the file-name prefix
        problem_type : str
            e.g., 'BiTSP', 'BiKP'
        problem_size : str
            e.g., 'small', 'medium', 'large'

        Returns:
        --------
        Path to the most recently modified matching YAML file, or None if
        not found
        """
        prefix = self.STUDY_PREFIXES.get(study, study)
        base = f"{prefix}_{problem_type}"

        # Size-specific names first: with and without a trailing suffix.
        for pattern in (f"{base}_{problem_size}_*.yaml", f"{base}_{problem_size}.yaml"):
            newest = self._newest(self.yaml_dir.glob(pattern))
            if newest is not None:
                return newest

        # Fallback for studies without size in name.  The wildcard would also
        # match files belonging to other sizes, so drop every candidate whose
        # remaining name contains another known size as a '_'-separated token.
        other_sizes = {
            size for sizes in self.SIZE_TO_N.values() for size in sizes
        } - {problem_size}
        unsized = [
            path for path in self.yaml_dir.glob(f"{base}_*.yaml")
            if other_sizes.isdisjoint(path.stem[len(base) + 1:].split('_'))
        ]
        return self._newest(unsized)

    def load_yaml(self, filepath: Path) -> Dict[str, Any]:
        """
        Load YAML file and return its parsed contents.

        The file is parsed with ``yaml.safe_load`` and its result is returned
        unchanged (NOTE(review): for an empty document that is presumably
        ``None`` rather than a dict - callers should be prepared for it).
        """
        # YAML streams are UTF-8 by default; do not depend on the locale.
        with open(filepath, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def get_n_value(self, problem_type: str, problem_size: str) -> int:
        """Get actual n value for problem type and size (50 if unknown)"""
        return self.SIZE_TO_N.get(problem_type, {}).get(problem_size, 50)

    def list_available_files(self) -> List[str]:
        """List the names of all YAML files in the directory"""
        return [f.name for f in self.yaml_dir.glob("*.yaml")]


# ============================================================================
# MAIN PLOTTER CLASS
# ============================================================================

class FTRLPaperPlotter:
    """
    Generate publication-ready figures from FTRL ablation study results.

    Every ``plot_*`` method builds one figure with one row per entry of
    ``scales`` and saves it to ``output_dir`` as both PNG and PDF.  A row
    whose YAML file cannot be found is filled with a "No data" placeholder
    instead of aborting the whole figure.

    Parameters:
    -----------
    yaml_dir : str
        Directory containing YAML result files
    output_dir : str
        Directory to save figures (created if it does not exist)
    problem_type : str
        Problem type ('BiTSP', 'BiKP', 'TriTSP')
    scales : List[str]
        Two scales to compare, e.g., ['small', 'large'].  Defaults to
        ['small', 'large'] when omitted or empty.

    Raises:
    -------
    ValueError
        If ``scales`` does not contain exactly two entries
    """

    # Color scheme
    COLORS = {
        'UCB_with_FTRL': '#0173B2',
        'UCB_without_FTRL': '#DE8F05',
        'Thompson_with_FTRL': '#029E73',
        'Thompson_without_FTRL': '#CC78BC',
        'UCB': '#0173B2',
        'Thompson': '#029E73',
        'With FTRL': '#0173B2',
        'Without FTRL': '#DE8F05'
    }

    # Color used for configurations that have no entry in COLORS
    FALLBACK_COLOR = '#949494'

    # Bandit algorithms compared in the rate study, in plotting order
    ALGORITHMS = ('UCB', 'Thompson')

    def __init__(
        self,
        yaml_dir: str,
        output_dir: str,
        problem_type: str = 'BiTSP',
        scales: Optional[List[str]] = None
    ):
        self.yaml_loader = YAMLLoader(yaml_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.problem_type = problem_type
        # Copy so that later changes to the caller's list do not leak in.
        self.scales = list(scales) if scales else ['small', 'large']

        if len(self.scales) != 2:
            raise ValueError(f"Exactly 2 scales required, got {len(self.scales)}")

        # Setup plotting style
        setup_icml_style()

        # Cache loaded data (key: "<study>_<size>", value: parsed YAML or None)
        self._data_cache = {}

        print("FTRLPaperPlotter initialized:")
        print(f"  YAML dir: {yaml_dir}")
        print(f"  Output dir: {output_dir}")
        print(f"  Problem type: {problem_type}")
        print(f"  Scales: {self.scales}")
        print(f"  Available YAMLs: {self.yaml_loader.list_available_files()}")

    # =========================================================================
    # SHARED HELPERS
    # =========================================================================

    def _load_study_data(self, study: str, problem_size: str) -> Optional[Dict]:
        """Load (and cache) data for a specific study and size; None if no file"""
        cache_key = f"{study}_{problem_size}"

        if cache_key not in self._data_cache:
            filepath = self.yaml_loader.find_yaml(study, self.problem_type, problem_size)
            if filepath:
                print(f"  Loading: {filepath.name}")
                self._data_cache[cache_key] = self.yaml_loader.load_yaml(filepath)
            else:
                print(f"  Warning: No YAML found for {study}/{self.problem_type}/{problem_size}")
                self._data_cache[cache_key] = None

        return self._data_cache[cache_key]

    def _get_n_label(self, problem_size: str) -> str:
        """Get label with n value for a scale"""
        n = self.yaml_loader.get_n_value(self.problem_type, problem_size)
        return f"n={n}"

    def _save_figure(self, fig, name: str):
        """Save figure as both PNG and PDF, then close it"""
        png_path = self.output_dir / f"{name}.png"
        pdf_path = self.output_dir / f"{name}.pdf"

        fig.savefig(png_path, dpi=300, bbox_inches='tight')
        fig.savefig(pdf_path, bbox_inches='tight')
        plt.close(fig)

        print(f"  Saved: {png_path.name}, {pdf_path.name}")

    @staticmethod
    def _split_config_name(config_name: str) -> Tuple[str, str]:
        """Split a config key such as 'UCB_without_FTRL' into (algorithm, FTRL label)

        Anything that does not mention 'UCB' is treated as Thompson, and
        anything that does not mention 'without' as using FTRL.
        """
        algorithm = 'UCB' if 'UCB' in config_name else 'Thompson'
        ftrl = 'Without FTRL' if 'without' in config_name else 'With FTRL'
        return algorithm, ftrl

    @staticmethod
    def _align_curves(curves) -> np.ndarray:
        """Stack per-run curves into a (runs, steps) array of equal-length rows

        Empty or missing curves are dropped; shorter curves are extended by
        repeating their last value.  Returns an empty array when no usable
        curve is left.
        """
        usable = [list(c) for c in (curves or []) if c is not None and len(c) > 0]
        if not usable:
            return np.empty((0, 0))
        longest = max(len(c) for c in usable)
        return np.array(
            [c + [c[-1]] * (longest - len(c)) for c in usable], dtype=float
        )

    @staticmethod
    def _stat_matrix(results: Dict, scenarios: List, configs: List[str],
                     stat_key: str) -> np.ndarray:
        """Build a (scenarios × configs) matrix of one statistic, 0 where absent"""
        return np.array(
            [
                [
                    ((results.get(scenario) or {}).get(config) or {}).get(stat_key) or 0
                    for config in configs
                ]
                for scenario in scenarios
            ],
            dtype=float,
        ).reshape(len(scenarios), len(configs))

    @staticmethod
    def _annotate_row(ax, label: str, x: float = -0.25):
        """Write a bold, rotated row label to the left of ``ax``"""
        ax.annotate(label, xy=(x, 0.5), xycoords='axes fraction',
                    fontsize=14, fontweight='bold', ha='center', va='center',
                    rotation=90)

    def _mark_no_data(self, row_axes, message: str,
                      row_label: Optional[str] = None, label_x: float = -0.25):
        """Fill every axes of a row with a centered placeholder message

        When ``row_label`` is given it is attached to the first axes so the
        row can still be identified in the figure.
        """
        for ax in row_axes:
            ax.text(0.5, 0.5, message, ha='center', va='center',
                    transform=ax.transAxes)
        if row_label is not None:
            self._annotate_row(row_axes[0], row_label, label_x)

    def _ftrl_palette(self) -> Dict[str, str]:
        """Palette for seaborn plots that use the With/Without FTRL hue"""
        return {key: self.COLORS[key] for key in ('With FTRL', 'Without FTRL')}

    def _draw_mean_band(self, ax, results: Dict, curve_key: str) -> bool:
        """Plot mean ± std across runs of ``curve_key`` for every configuration

        Returns True if at least one configuration had curves to draw.
        """
        drew_any = False
        for config_name, config_data in results.items():
            aligned = self._align_curves((config_data or {}).get(curve_key))
            if aligned.size == 0:
                continue

            mean_curve = aligned.mean(axis=0)
            std_curve = aligned.std(axis=0)
            steps = np.arange(mean_curve.size)
            color = self.COLORS.get(config_name, self.FALLBACK_COLOR)

            ax.plot(steps, mean_curve, label=config_name.replace('_', ' '),
                    color=color, linewidth=2)
            ax.fill_between(steps, mean_curve - std_curve, mean_curve + std_curve,
                            alpha=0.2, color=color)
            drew_any = True
        return drew_any

    # =========================================================================
    # FIGURE 1: RATE STUDY (2 rows × 3 cols)
    # =========================================================================

    def plot_rate_study(self):
        """
        Figure 1: FTRL Rate Study

        Layout: 2 rows (scales) × 3 cols (Mean±Std, HV Distrib, Worst-case)
        """
        print("\nGenerating Figure 1: Rate Study...")

        fig, axes = plt.subplots(2, 3, figsize=(14, 8))
        alg_palette = {alg: self.COLORS[alg] for alg in self.ALGORITHMS}

        for row_idx, scale in enumerate(self.scales):
            data = self._load_study_data('rate', scale)
            n_label = self._get_n_label(scale)
            first_row = row_idx == 0  # column titles only on the top row

            if data is None:
                self._mark_no_data(axes[row_idx], f'No data for {scale}', n_label)
                continue

            # ``or`` also covers sections that are present but null in the YAML.
            results = data.get('results') or {}
            ftrl_rates = data.get('ftrl_rates') or [0.0, 0.3, 0.5, 0.7, 1.0]

            # One pass over the results feeds both the summary curves and the
            # long-format rows needed by the box plot.
            summary = {alg: [] for alg in self.ALGORITHMS}
            box_rows = []
            for rate in ftrl_rates:
                for alg in self.ALGORITHMS:
                    config = results.get(f"{alg}_rate{rate}") or {}
                    hvs = config.get('hypervolume') or []
                    if not hvs:
                        continue
                    summary[alg].append({
                        'rate': rate,
                        'mean': np.mean(hvs),
                        'std': np.std(hvs),
                        'min': np.min(hvs),
                    })
                    box_rows.extend(
                        {
                            'FTRL Rate': f"{rate*100:.0f}%",
                            'Algorithm': alg,
                            'Hypervolume': hv
                        }
                        for hv in hvs
                    )
            has_curves = any(summary.values())

            ax0, ax1, ax2 = axes[row_idx]

            # Column 0: Mean ± Std vs FTRL Rate
            for alg in self.ALGORITHMS:
                points = summary[alg]
                if points:
                    ax0.errorbar([p['rate'] * 100 for p in points],
                                 [p['mean'] for p in points],
                                 yerr=[p['std'] for p in points],
                                 marker='o', label=alg, color=alg_palette[alg],
                                 capsize=3, linewidth=2)

            ax0.set_xlabel('FTRL Usage Rate (%)')
            ax0.set_ylabel('Hypervolume')
            ax0.set_title('Mean ± Std' if first_row else '')
            if has_curves:
                ax0.legend(loc='best')
            ax0.grid(True, alpha=0.3)

            # Add scale label on left
            self._annotate_row(ax0, n_label)

            # Column 1: HV Distribution (box plot per rate, split by algorithm)
            if box_rows:
                sns.boxplot(data=pd.DataFrame(box_rows), x='FTRL Rate',
                            y='Hypervolume', hue='Algorithm', ax=ax1,
                            palette=alg_palette)
                ax1.set_title('HV Distribution' if first_row else '')
                ax1.legend(loc='best')

            # Column 2: Worst-case (min) performance
            for alg in self.ALGORITHMS:
                points = summary[alg]
                if points:
                    ax2.plot([p['rate'] * 100 for p in points],
                             [p['min'] for p in points],
                             marker='^', label=alg, color=alg_palette[alg],
                             linewidth=2)

            ax2.set_xlabel('FTRL Usage Rate (%)')
            ax2.set_ylabel('Min Hypervolume')
            ax2.set_title('Worst-Case' if first_row else '')
            if has_curves:
                ax2.legend(loc='best')
            ax2.grid(True, alpha=0.3)

        fig.suptitle(f'FTRL Rate Study: {self.problem_type}', fontsize=16, fontweight='bold')
        # rect is in normalized figure coordinates, so it must stay within
        # [0, 1]; the left/top margins leave room for row labels and suptitle.
        fig.tight_layout(rect=[0.03, 0, 1, 0.96])

        self._save_figure(fig, f'fig1_rate_study_{self.problem_type}')

    # =========================================================================
    # FIGURE 2: VARIANCE ANALYSIS (2 rows × 4 cols)
    # =========================================================================

    def plot_variance(self):
        """
        Figure 2: Variance Analysis

        Layout: 2 rows (scales) × 4 cols (Box, Violin, CV, Variance Reduction %)
        """
        print("\nGenerating Figure 2: Variance Analysis...")

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        ftrl_palette = self._ftrl_palette()

        for row_idx, scale in enumerate(self.scales):
            data = self._load_study_data('variance', scale)
            n_label = self._get_n_label(scale)
            first_row = row_idx == 0

            if data is None:
                self._mark_no_data(axes[row_idx], f'No data for {scale}',
                                   n_label, label_x=-0.3)
                continue

            raw_results = data.get('raw_results') or {}
            variance_stats = data.get('variance_statistics') or {}

            # Long-format table: one row per individual hypervolume sample
            samples = []
            for config_name, config_data in raw_results.items():
                algorithm, ftrl = self._split_config_name(config_name)
                for hv in (config_data or {}).get('hypervolume') or []:
                    samples.append({
                        'Algorithm': algorithm,
                        'FTRL': ftrl,
                        'Hypervolume': hv
                    })
            df = pd.DataFrame(samples)

            ax0, ax1, ax2, ax3 = axes[row_idx]

            # Column 0: Box plot
            if len(df) > 0:
                sns.boxplot(data=df, x='Algorithm', y='Hypervolume', hue='FTRL',
                            ax=ax0, palette=ftrl_palette)
                ax0.set_title('Box Plot' if first_row else '')
                ax0.legend(loc='best', title='')

            self._annotate_row(ax0, n_label, x=-0.3)

            # Column 1: Violin plot
            if len(df) > 0:
                # Split violins need exactly two hue levels; fall back to
                # ordinary violins when only one FTRL setting is present.
                sns.violinplot(data=df, x='Algorithm', y='Hypervolume', hue='FTRL',
                               ax=ax1, palette=ftrl_palette,
                               split=df['FTRL'].nunique() == 2)
                ax1.set_title('Violin Plot' if first_row else '')
                ax1.legend(loc='best', title='')

            # Column 2: Coefficient of Variation
            cv_rows = []
            for config_name, stats in variance_stats.items():
                if isinstance(stats, dict) and 'cv' in stats:
                    algorithm, ftrl = self._split_config_name(config_name)
                    cv_rows.append({
                        'Algorithm': algorithm,
                        'FTRL': ftrl,
                        'CV': stats['cv']
                    })

            if cv_rows:
                sns.barplot(data=pd.DataFrame(cv_rows), x='Algorithm', y='CV',
                            hue='FTRL', ax=ax2, palette=ftrl_palette)
                ax2.set_ylabel('Coefficient of Variation')
                ax2.set_title('CV (lower=better)' if first_row else '')
                ax2.legend(loc='best', title='')

            # Column 3: Variance Reduction %
            # Keys look like '<algorithm>_..._variance_reduction...'; only
            # numeric entries can be drawn as bars.
            reductions = [
                (key.split('_')[0], value)
                for key, value in variance_stats.items()
                if 'variance_reduction' in key and isinstance(value, (int, float))
            ]

            if reductions:
                names = [name for name, _ in reductions]
                values = [value for _, value in reductions]
                bar_colors = ['#029E73' if v >= 0 else '#CC78BC' for v in values]
                bars = ax3.bar(names, values, color=bar_colors)
                ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
                ax3.set_ylabel('Variance Reduction (%)')
                ax3.set_title('FTRL Effect' if first_row else '')

                # Add value labels just beyond the end of each bar.  The gap
                # is given in points so it does not depend on the data range.
                for bar, val in zip(bars, values):
                    positive = val >= 0
                    ax3.annotate(f'{val:.1f}%',
                                 xy=(bar.get_x() + bar.get_width() / 2,
                                     bar.get_height()),
                                 xytext=(0, 3 if positive else -3),
                                 textcoords='offset points', ha='center',
                                 va='bottom' if positive else 'top',
                                 fontweight='bold', fontsize=11)

        fig.suptitle(f'Variance Analysis: {self.problem_type}', fontsize=16, fontweight='bold')
        # Stay inside the figure and keep the top strip free for the suptitle.
        fig.tight_layout(rect=[0.03, 0, 1, 0.96])

        self._save_figure(fig, f'fig2_variance_{self.problem_type}')

    # =========================================================================
    # FIGURE 3: STRESS TESTS (2 rows × 3 cols)
    # =========================================================================

    def plot_stress_tests(self):
        """
        Figure 3: Stress Tests

        Layout: 2 rows (scales) × 3 cols (Mean Heatmap, Std Heatmap, Worst-case)
        """
        print("\nGenerating Figure 3: Stress Tests...")

        # Adjusted figure size for better proportions
        fig, axes = plt.subplots(2, 3, figsize=(16, 9))

        # Config labels for plots (shortened for heatmaps)
        configs = ['UCB_with_FTRL', 'UCB_without_FTRL', 'Thompson_with_FTRL', 'Thompson_without_FTRL']
        config_labels = ['UCB\nw/ FTRL', 'UCB\nw/o FTRL', 'TS\nw/ FTRL', 'TS\nw/o FTRL']
        bar_labels = ['UCB w/ FTRL', 'UCB w/o FTRL', 'TS w/ FTRL', 'TS w/o FTRL']

        for row_idx, scale in enumerate(self.scales):
            data = self._load_study_data('stress', scale)
            first_row = row_idx == 0

            scenarios = list(data.get('scenarios') or []) if data is not None else []
            # A heatmap cannot be drawn from a zero-size matrix, so a file
            # without scenarios is shown like a missing file.  Row labels are
            # added for all rows further below.
            if not scenarios:
                self._mark_no_data(axes[row_idx], f'No data for {scale}')
                continue

            results = data.get('results') or {}

            # Prepare data matrices (rows: scenarios, columns: configs)
            mean_matrix = self._stat_matrix(results, scenarios, configs, 'mean_hv')
            std_matrix = self._stat_matrix(results, scenarios, configs, 'std_hv')
            min_matrix = self._stat_matrix(results, scenarios, configs, 'min_hv')

            ax0, ax1, ax2 = axes[row_idx]

            # Column 0: Mean HV Heatmap
            sns.heatmap(mean_matrix, annot=True, fmt='.3f', cmap='RdYlGn',
                        xticklabels=config_labels, yticklabels=scenarios, ax=ax0,
                        cbar_kws={'shrink': 0.7}, annot_kws={'size': 10},
                        square=False)
            ax0.set_title('Mean Hypervolume' if first_row else '', fontweight='bold')
            ax0.set_ylabel('')  # Remove ylabel, the row label is used instead

            # Column 1: Std HV Heatmap (lower is better, hence reversed cmap)
            sns.heatmap(std_matrix, annot=True, fmt='.3f', cmap='RdYlGn_r',
                        xticklabels=config_labels, yticklabels=scenarios, ax=ax1,
                        cbar_kws={'shrink': 0.7}, annot_kws={'size': 10},
                        square=False)
            ax1.set_title('Std Dev (lower=better)' if first_row else '', fontweight='bold')
            ax1.set_ylabel('')

            # Column 2: Worst-case (min) by scenario, one bar group per scenario
            x = np.arange(len(scenarios))
            width = 0.18  # Slightly narrower bars

            for idx, (config, label) in enumerate(zip(configs, bar_labels)):
                # Center the four bars of a group around the tick position
                offset = (idx - 1.5) * width
                ax2.bar(x + offset, min_matrix[:, idx], width, label=label,
                        color=self.COLORS.get(config, self.FALLBACK_COLOR),
                        alpha=0.8)

            ax2.set_xticks(x)
            ax2.set_xticklabels(scenarios, rotation=45, ha='right')
            ax2.set_ylabel('Min Hypervolume')
            ax2.set_title('Worst-Case Performance' if first_row else '', fontweight='bold')
            ax2.grid(True, alpha=0.3, axis='y')

        # Add row labels (scale) on the left side using fig.text
        for row_idx, scale in enumerate(self.scales):
            n_label = self._get_n_label(scale)
            # Position: x=0.02 (far left), y depends on row
            y_pos = 0.72 - row_idx * 0.42
            fig.text(0.02, y_pos, n_label, fontsize=14, fontweight='bold',
                     ha='center', va='center', rotation=90)

        # Add single legend outside the plots at the bottom.  Use the first
        # row that actually drew bars so the legend survives a missing row.
        for bar_ax in axes[:, 2]:
            handles, labels = bar_ax.get_legend_handles_labels()
            if handles:
                fig.legend(handles, labels, loc='upper center', ncol=4, fontsize=11,
                           bbox_to_anchor=(0.5, 0.1), frameon=True, fancybox=True)
                break

        fig.suptitle(f'Stress Tests: {self.problem_type}', fontsize=16, fontweight='bold')

        # Adjust layout - leave room at left for row labels and bottom for legend
        fig.tight_layout(rect=[0.05, 0.06, 0.98, 0.96])

        self._save_figure(fig, f'fig3_stress_{self.problem_type}')

    # =========================================================================
    # FIGURE 4: REGRET ANALYSIS (2 rows × 3 cols)
    # =========================================================================

    def plot_regret(self):
        """
        Figure 4: Regret Analysis

        Layout: 2 rows (scales) × 3 cols (Cumulative Regret, Convergence, Stability)
        """
        print("\nGenerating Figure 4: Regret Analysis...")

        fig, axes = plt.subplots(2, 3, figsize=(14, 8))
        ftrl_palette = self._ftrl_palette()

        for row_idx, scale in enumerate(self.scales):
            data = self._load_study_data('regret', scale)
            n_label = self._get_n_label(scale)
            first_row = row_idx == 0

            if data is None:
                self._mark_no_data(axes[row_idx], f'No data for {scale}', n_label)
                continue

            results = data.get('results') or {}
            ax0, ax1, ax2 = axes[row_idx]

            # Column 0: Cumulative Regret Curves
            has_regret = self._draw_mean_band(ax0, results, 'regret_curves')
            if not has_regret:
                ax0.text(0.5, 0.5, 'No regret data', ha='center', va='center',
                         transform=ax0.transAxes)

            ax0.set_xlabel('Iteration')
            ax0.set_ylabel('Cumulative Regret')
            ax0.set_title('Cumulative Regret' if first_row else '')
            if has_regret:
                ax0.legend(loc='best', fontsize=10)
            ax0.grid(True, alpha=0.3)

            self._annotate_row(ax0, n_label)

            # Column 1: Convergence Trajectories (best reward over time)
            has_conv = self._draw_mean_band(ax1, results, 'best_rewards')
            if not has_conv:
                ax1.text(0.5, 0.5, 'No convergence data', ha='center', va='center',
                         transform=ax1.transAxes)

            ax1.set_xlabel('Iteration')
            ax1.set_ylabel('Best Reward')
            ax1.set_title('Convergence Trajectories' if first_row else '')
            if has_conv:
                ax1.legend(loc='best', fontsize=10)
            ax1.grid(True, alpha=0.3)

            # Column 2: Stability Score (Mean/Std ratio)
            stability_rows = []
            for config_name, config_data in results.items():
                hvs = (config_data or {}).get('hypervolume') or []
                # Only strictly positive hypervolumes enter the score
                valid_hvs = [h for h in hvs if h and h > 0]
                if not valid_hvs:
                    continue

                # The epsilon avoids a division by zero for identical runs
                stability = np.mean(valid_hvs) / (np.std(valid_hvs) + 1e-10)
                algorithm, ftrl = self._split_config_name(config_name)
                stability_rows.append({
                    'Algorithm': algorithm,
                    'FTRL': ftrl,
                    'Stability': min(stability, 100)  # Cap for visualization
                })

            if stability_rows:
                sns.barplot(data=pd.DataFrame(stability_rows), x='Algorithm',
                            y='Stability', hue='FTRL', ax=ax2, palette=ftrl_palette)
                ax2.set_title('Stability (higher=better)' if first_row else '')
                ax2.legend(loc='best', fontsize=10)
            else:
                ax2.text(0.5, 0.5, 'No stability data', ha='center', va='center',
                         transform=ax2.transAxes)

            ax2.grid(True, alpha=0.3, axis='y')

        fig.suptitle(f'Regret Analysis: {self.problem_type}', fontsize=16, fontweight='bold')
        fig.tight_layout(rect=[0.03, 0, 1, 0.96])

        self._save_figure(fig, f'fig4_regret_{self.problem_type}')

    # =========================================================================
    # PLOT ALL
    # =========================================================================

    def plot_all(self):
        """Generate all 4 figures."""
        print(f"\n{'='*60}")
        print("Generating all FTRL paper figures")
        print(f"{'='*60}")

        self.plot_rate_study()
        self.plot_variance()
        self.plot_stress_tests()
        self.plot_regret()

        print(f"\n{'='*60}")
        print(f"All figures saved to: {self.output_dir}")
        print(f"{'='*60}")



# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Example usage: build the four paper figures for one dataset.
    example_settings = dict(
        yaml_dir="ftrl_comprehensive_results/",  # where the YAML results live
        output_dir="paper_figures/",              # where the figures are written
        problem_type="BiTSP",                     # dataset to plot
        scales=['small', 'large'],                # the two scales to compare
    )
    plotter = FTRLPaperPlotter(**example_settings)

    # Render every figure in one go.
    plotter.plot_all()

    # Individual figures can be rendered instead by calling one of
    # plot_rate_study(), plot_variance(), plot_stress_tests() or
    # plot_regret() on the plotter.
    #
    # For the BiKP dataset, create a second plotter with the same
    # directories and scales but problem_type="BiKP", then call plot_all().