"""
Comprehensive FTRL Ablation Study

This module provides an in-depth analysis of how FTRL (Follow-The-Regularized-Leader)
affects the UCB-Hedge and Thompson-Hedge algorithms for multi-objective
combinatorial optimization.

Key Analysis Dimensions:
1. Variance Reduction - FTRL's primary theoretical strength
2. Stress Tests - Scenarios where FTRL should excel (non-stationary, high-noise)
3. FTRL Usage Rate - vary how often FTRL is used (0% to 100%; 30% is the default) to isolate its effect
4. Regret Analysis - Cumulative regret curves over iterations
5. Convergence Stability - Smoothness of reward trajectories
6. Worst-Case Performance - Min across seeds, not just mean

Author: Research Team
Date: 2024
"""

import os
import yaml
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any, List, Tuple, Optional, Callable
from collections import defaultdict
import torch
import warnings
import pickle
import json
from scipy import stats
from scipy.ndimage import uniform_filter1d
import sys
from pathlib import Path

import sys

# Import evaluator
from MOCO.evaluation import MOCOEvaluator

from MOCO.problems import MultiObjectiveKnapsack, BiObjectiveTSP
from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
    CachedAdvancedBiKPWrapper as UCBWrapper
)
from our_method_dl_Thompson_variant import (
    CachedAdvancedBiKPWrapper as ThompsonWrapper
)

# Import the existing tracking-enabled wrapper for UCB
from enhanced_wrapper_with_tracking import (
    CachedAdvancedBiKPWrapperWithTracking as UCBTrackingWrapper
)

# For Thompson, we need to create an equivalent class
# Import the base Thompson wrapper and MetricsTracker
from enhanced_wrapper_with_tracking import MetricsTracker
from our_method_dl_Thompson_variant import (
    AdvancedDecompositionWrapper as ThompsonAdvancedWrapper,
    # Fix edDecomposedGameOptUCBHedge as FixedDecomposedGameOptThompsonHedge
)

# Base Thompson wrapper; get_tracking_wrapper_classes() subclasses it to build
# the Thompson tracking wrapper (the wrapper creates its optimizer internally,
# so no optimizer import is needed).
from our_method_dl_Thompson_variant import (
    CachedAdvancedBiKPWrapper as ThompsonBaseWrapper
)

# Silence every Python warning for the whole process to keep experiment logs
# readable (this also hides genuine deprecation/runtime warnings).
warnings.simplefilter('ignore')

# Make this script's directory importable for imports executed after this point.
# NOTE(review): the project imports at the top of the file have already run,
# so this entry cannot help resolve *them*; move it above those imports if
# it is actually needed for them.
project_root = Path(__file__).parent
sys.path.append(os.fspath(project_root))




# ============================================================================
# ICML STYLE CONFIGURATION
# ============================================================================

def setup_icml_style(base_style: str = 'seaborn-v0_8-paper'):
    """
    Configure matplotlib/seaborn globally for ICML-style paper figures.

    Applies a paper-sized seaborn base style, a colorblind-friendly palette,
    and compact font / grid / line settings through ``plt.rcParams``.

    Args:
        base_style: matplotlib style to start from. The default
            ``'seaborn-v0_8-paper'`` only exists in matplotlib >= 3.6; when the
            requested style is unavailable, the legacy ``'seaborn-paper'`` name
            is tried, and if that is missing too the current style is kept.
            This matters because the function runs at import time: a missing
            style name must not make the whole module fail to import.
    """
    for style_name in (base_style, 'seaborn-paper'):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            # matplotlib raises OSError for unknown style names.
            continue

    # Colorblind-friendly palette
    colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161', '#949494']
    sns.set_palette(sns.color_palette(colors))

    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
        'font.size': 9,
        'axes.labelsize': 10,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
        'legend.frameon': True,
        'legend.framealpha': 0.9,
    })


# Applied once at import so every figure produced by this module shares the style.
setup_icml_style()


# ============================================================================
# TRACKING WRAPPER - Use existing CachedAdvancedBiKPWrapperWithTracking
# ============================================================================

# Architecture:
# - AdvancedDecompositionWrapper: High-level multi-objective handler
#   - Generates weight vectors, maintains Pareto archive
#   - Runs optimizer once per weight vector
#
# - FixedDecomposedGameOptUCBHedge: Low-level single-objective optimizer  
#   - The actual UCB-Hedge bandit algorithm
#   - Runs for max_iterations per weight vector
#
# CachedAdvancedBiKPWrapperWithTracking extends the wrapper and defines
# a tracking-enabled optimizer internally, so we only need to import the wrapper.

def get_tracking_wrapper_classes():
    """
    Build the pair of tracking-enabled wrapper classes.

    Returns:
        tuple: ``(UCBTrackingWrapper, ThompsonTrackingWrapper)`` where

        * ``UCBTrackingWrapper`` is the imported
          ``CachedAdvancedBiKPWrapperWithTracking`` from
          ``enhanced_wrapper_with_tracking``.
        * ``ThompsonTrackingWrapper`` is always defined here, on top of the
          Thompson variant's ``CachedAdvancedBiKPWrapper``. It only derives a
          coarse reward trace after ``run()`` (see its docstring).

    A new Thompson class object is created on every call; use
    ``get_tracking_wrappers()`` to obtain a cached pair.
    """

    class ThompsonTrackingWrapper(ThompsonBaseWrapper):
        """
        Thompson-Sampling version with (simplified) tracking.

        Extends the Thompson ``CachedAdvancedBiKPWrapper`` and fills a
        ``MetricsTracker`` after ``run()``. Unlike the UCB tracking wrapper it
        does not track the inner optimizer; for full parity ``run()`` would
        need to use a tracking-enabled optimizer.
        """

        def __init__(self, problem, **kwargs):
            # Tracking storage. ``all_trackers`` is never filled by this class;
            # it is kept so callers can use the same accessors as for UCB.
            self.all_trackers = []
            self.aggregate_tracker = None

            super().__init__(problem, **kwargs)

        def run(self):
            """
            Run the base Thompson wrapper, then derive a coarse reward trace.

            After ``super().run()``, a fresh ``MetricsTracker`` is stored in
            ``aggregate_tracker``. If the instance exposes a non-empty
            ``rewards`` sequence, entry ``i`` of ``sequential_learning_metrics``
            records the running best reward up to ``i`` and the change from
            reward ``i-1`` (0 for the first entry).

            Returns:
                Whatever the base wrapper's ``run()`` returned.
            """
            result = super().run()

            self.aggregate_tracker = MetricsTracker()

            # NOTE(review): the base wrapper is not shown to set ``self.rewards``;
            # if it does not, the tracker simply stays empty.
            rewards = getattr(self, 'rewards', None)
            if rewards is None:
                rewards = []
            else:
                # Materialize as a list so NumPy arrays (whose truth value is
                # ambiguous) and other sequences are handled uniformly.
                try:
                    rewards = list(rewards)
                except TypeError:
                    rewards = []

            running_best = None
            for i, reward in enumerate(rewards):
                # Running maximum (same comparison as builtin max over the
                # prefix) keeps this O(n) instead of re-scanning each prefix.
                if running_best is None or reward > running_best:
                    running_best = reward
                previous = rewards[i - 1] if i > 0 else reward
                self.aggregate_tracker.sequential_learning_metrics[i] = {
                    'cumulative_best_reward': running_best,
                    'total_iteration_improvement': reward - previous,
                    'subproblem_improvements': []
                }

            return result

        def get_tracker(self):
            """Return the tracker built by the last ``run()`` (None before any run)."""
            return self.aggregate_tracker

        def get_all_trackers(self):
            """Return the per-subproblem tracker list (always empty for this class)."""
            return self.all_trackers

    return UCBTrackingWrapper, ThompsonTrackingWrapper


def extract_tracking_data(wrapper) -> Dict[str, Any]:
    """
    Extract iteration-level tracking data from a wrapper for regret analysis.

    Works with any object exposing ``get_tracker()`` that returns a tracker
    with a ``sequential_learning_metrics`` mapping (iteration -> metrics dict).
    Iterations are processed in sorted key order.

    Returns:
        Dict with:

        * ``iteration_rewards`` - ``total_iteration_improvement`` per iteration
          (0 when missing).
        * ``best_rewards`` - ``cumulative_best_reward`` per iteration
          (0 when missing).
        * ``regret_curve`` - cumulative sum of ``max(best_rewards) - best``.
          Regret is measured against the best value reached in this run, not
          a known optimum.
        * ``stability_metrics`` - only when there are at least 10 iterations:
          ``smoothness`` = 1 / (1 + var of consecutive differences),
          ``oscillation`` = sign changes of those differences / their count,
          ``final_stability`` = 1 / (1 + var of the last 20% of values).

        All fields are empty when no tracker or no metrics are available.
    """
    def _empty() -> Dict[str, Any]:
        # Fresh containers on every call so callers can mutate the result safely.
        return {
            'iteration_rewards': [],
            'best_rewards': [],
            'regret_curve': [],
            'stability_metrics': {}
        }

    tracker = wrapper.get_tracker() if hasattr(wrapper, 'get_tracker') else None

    # Missing tracker, missing attribute, None or empty mapping -> no data.
    metrics_by_iter = getattr(tracker, 'sequential_learning_metrics', None)
    if not metrics_by_iter:
        return _empty()

    iterations = sorted(metrics_by_iter)
    best_rewards = [
        metrics_by_iter[it].get('cumulative_best_reward', 0) for it in iterations
    ]
    iteration_improvements = [
        metrics_by_iter[it].get('total_iteration_improvement', 0) for it in iterations
    ]

    # Cumulative regret relative to this run's own best value.
    optimal = max(best_rewards)
    regret_curve = np.cumsum([optimal - r for r in best_rewards]).tolist()

    stability_metrics = {}
    if len(best_rewards) >= 10:
        rewards_array = np.array(best_rewards)
        diffs = np.diff(rewards_array)  # at least 9 entries here

        smoothness = 1.0 / (1.0 + np.var(diffs))

        # Any change of sign (including to/from zero) counts as an oscillation.
        sign_changes = np.sum(np.diff(np.sign(diffs)) != 0)
        oscillation = sign_changes / len(diffs)

        # Last 20% of the trajectory (never empty for >= 10 points).
        final_portion = rewards_array[int(0.8 * len(rewards_array)):]
        final_stability = 1.0 / (1.0 + np.var(final_portion))

        stability_metrics = {
            'smoothness': smoothness,
            'oscillation': oscillation,
            'final_stability': final_stability
        }

    return {
        'iteration_rewards': iteration_improvements,
        'best_rewards': best_rewards,
        'regret_curve': regret_curve,
        'stability_metrics': stability_metrics
    }


# Lazily-filled module cache for the (UCB, Thompson) tracking wrapper classes.
_TRACKING_WRAPPER_CLASSES = None

def get_tracking_wrappers():
    """
    Return the ``(UCBTrackingWrapper, ThompsonTrackingWrapper)`` pair.

    The pair is built by ``get_tracking_wrapper_classes()`` on the first call
    and reused afterwards, so every caller sees the same class objects.
    """
    global _TRACKING_WRAPPER_CLASSES
    if _TRACKING_WRAPPER_CLASSES is not None:
        return _TRACKING_WRAPPER_CLASSES
    _TRACKING_WRAPPER_CLASSES = get_tracking_wrapper_classes()
    return _TRACKING_WRAPPER_CLASSES


# ============================================================================
# STRESS TEST SCENARIOS
# ============================================================================

class StressTestScenario:
    """
    Common base for stress-test scenarios.

    Each scenario carries a short ``name`` and a human-readable
    ``description``. Subclasses override ``modify_evaluation`` and, when they
    keep per-run state, ``reset``. ``modify_problem`` is an optional hook.
    """

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description

    def modify_problem(self, problem) -> Any:
        """Return a stress-tested version of *problem*; subclasses must override."""
        raise NotImplementedError

    def modify_evaluation(self, eval_fn: Callable) -> Callable:
        """Return a stress-tested wrapper around *eval_fn*; subclasses must override."""
        raise NotImplementedError

    def reset(self):
        """Clear per-run state before a new run (nothing to clear by default)."""
        return None


class HighVarianceScenario(StressTestScenario):
    """
    Noisy-feedback scenario: every reward gets independent Gaussian noise.

    The noise standard deviation is relative to each value's magnitude,
    ``noise_std * |r| + 0.1``; the ``0.1`` floor keeps some noise even for
    zero rewards. Noise comes from NumPy's global RNG.
    """

    def __init__(self, noise_std: float = 0.5):
        super().__init__(
            name=f"HighVariance_std{noise_std}",
            description=f"Gaussian noise with std={noise_std} added to rewards"
        )
        self.noise_std = noise_std

    def modify_evaluation(self, eval_fn: Callable) -> Callable:
        """Wrap *eval_fn* so each returned value is perturbed by Gaussian noise."""
        std = self.noise_std

        def jitter(value):
            # Scale grows with |value|, with an absolute floor of 0.1.
            return value + np.random.normal(0, std * abs(value) + 0.1)

        def noisy_eval(solution):
            reward = eval_fn(solution)
            if not isinstance(reward, (list, tuple)):
                return jitter(reward)
            # Multi-objective: perturb each objective in order, keep the container type.
            return type(reward)([jitter(r) for r in reward])

        return noisy_eval


class NonStationaryScenario(StressTestScenario):
    """
    Drifting-objective scenario: the reward scale jumps at fixed evaluation counts.

    Each call of the wrapped evaluator increments ``eval_count``. When the
    count equals an entry of ``change_points``, ``current_bias`` moves by
    ``+/- change_magnitude`` (sign drawn from NumPy's global RNG). Rewards
    are multiplied by ``1 + bias``; for vector rewards the bias sign
    alternates across objectives (+ for even indices, - for odd).
    """

    def __init__(self, change_points: List[int] = None, change_magnitude: float = 0.3):
        super().__init__(
            name=f"NonStationary_mag{change_magnitude}",
            description=f"Objective changes at specified iterations with magnitude {change_magnitude}"
        )
        # A missing *or empty* schedule falls back to the default change points.
        self.change_points = change_points if change_points else [50, 100, 150]
        self.change_magnitude = change_magnitude
        self.eval_count = 0
        self.current_bias = 0.0

    def modify_evaluation(self, eval_fn: Callable) -> Callable:
        """Wrap *eval_fn* so its output follows the scenario's drifting bias."""
        state = self

        def nonstationary_eval(solution):
            state.eval_count += 1

            # One random-signed shift per matching change point
            # (a duplicated change point therefore shifts more than once).
            hits = sum(1 for cp in state.change_points if state.eval_count == cp)
            for _ in range(hits):
                state.current_bias += np.random.choice([-1, 1]) * state.change_magnitude

            reward = eval_fn(solution)
            if not isinstance(reward, (list, tuple)):
                return reward * (1 + state.current_bias)
            return type(reward)([
                r * (1 + state.current_bias * ((-1) ** idx))
                for idx, r in enumerate(reward)
            ])

        return nonstationary_eval

    def reset(self):
        """Restore the evaluation counter and bias to their initial values."""
        self.eval_count, self.current_bias = 0, 0.0


class AdversarialScenario(StressTestScenario):
    """
    Adversarial rewards - rewards shrink when a solution resembles recent ones.

    Each evaluated solution is compared position-by-position with each of
    the last ``max_history`` stored solutions. The penalty is the average of
    ``similarity * adversarial_strength`` over that window, where entries that
    cannot be compared count as 0. Every returned reward is then multiplied
    by ``1 - penalty``. This is intended as a worst case for FTRL.

    NOTE(review): every recent choice is penalized, not only successful ones,
    although ``description`` mentions "recent successful actions".
    """

    def __init__(self, adversarial_strength: float = 0.2):
        super().__init__(
            name=f"Adversarial_str{adversarial_strength}",
            description=f"Rewards penalize recent successful actions (strength={adversarial_strength})"
        )
        self.adversarial_strength = adversarial_strength
        self.recent_solutions = []
        self.max_history = 20

    @staticmethod
    def _similarity(current, previous):
        """
        Fraction of equal positions between two equal-length sequences.

        Returns None when they cannot be compared: either object lacks
        ``len``, the lengths differ, both are empty, or the element-wise
        comparison fails.
        """
        try:
            size = len(current)
            if size == 0 or len(previous) != size:
                return None
            return sum(a == b for a, b in zip(current, previous)) / size
        except (TypeError, ValueError):
            # Non-sized (scalar) solutions or incomparable elements.
            return None

    def modify_evaluation(self, eval_fn: Callable) -> Callable:
        """Wrap *eval_fn* so rewards are penalized by similarity to recent solutions."""
        scenario = self

        def adversarial_eval(solution):
            true_reward = eval_fn(solution)

            # Average similarity-weighted penalty over the recent window.
            penalty = 0.0
            if scenario.recent_solutions:
                for prev_sol in scenario.recent_solutions[-scenario.max_history:]:
                    similarity = scenario._similarity(solution, prev_sol)
                    if similarity is not None:
                        penalty += similarity * scenario.adversarial_strength
                penalty /= min(len(scenario.recent_solutions), scenario.max_history)

            # Remember this solution (copied to a list when iterable). Once the
            # history exceeds twice max_history, trim it back to max_history.
            scenario.recent_solutions.append(list(solution) if hasattr(solution, '__iter__') else solution)
            if len(scenario.recent_solutions) > scenario.max_history * 2:
                scenario.recent_solutions = scenario.recent_solutions[-scenario.max_history:]

            if isinstance(true_reward, (list, tuple)):
                return type(true_reward)([r * (1 - penalty) for r in true_reward])
            return true_reward * (1 - penalty)

        return adversarial_eval

    def reset(self):
        """Forget all stored solutions before a new run."""
        self.recent_solutions = []


class NearOptimalPlateauScenario(StressTestScenario):
    """
    Many solutions with similar rewards - tests exploration efficiency.

    After a warm-up of 20 evaluations, any reward above the running
    ``plateau_threshold`` quantile of previously seen values is replaced by
    ``quantile + plateau_noise * U[0, 1)``, flattening the top of the reward
    landscape. For vector rewards the quantile is computed per objective.
    Comparing each objective with a quantile of objective *sums* would put
    the threshold on the wrong scale, so compression would almost never
    trigger.
    """

    # Number of evaluations observed before compression starts.
    _WARMUP = 20

    def __init__(self, plateau_threshold: float = 0.9, plateau_noise: float = 0.05):
        super().__init__(
            name=f"Plateau_thr{plateau_threshold}",
            description=f"Rewards above {plateau_threshold} quantile compressed to similar values"
        )
        self.plateau_threshold = plateau_threshold
        self.plateau_noise = plateau_noise
        # Scalar rewards, or the sum of objectives for vector rewards.
        self.reward_history = []
        # Vector rewards only: objective index -> list of observed values.
        self.objective_history = {}

    def modify_evaluation(self, eval_fn: Callable) -> Callable:
        """Wrap *eval_fn* so rewards above the running quantile are flattened."""
        scenario = self

        def plateau_eval(solution):
            true_reward = eval_fn(solution)
            is_vector = isinstance(true_reward, (list, tuple))

            # Track rewards to estimate the distribution(s).
            if is_vector:
                scenario.reward_history.append(sum(true_reward))
                for i, r in enumerate(true_reward):
                    scenario.objective_history.setdefault(i, []).append(r)
            else:
                scenario.reward_history.append(true_reward)

            if len(scenario.reward_history) <= scenario._WARMUP:
                return true_reward

            q = scenario.plateau_threshold * 100

            if is_vector:
                compressed = []
                for i, r in enumerate(true_reward):
                    # Each objective is compared with its own quantile.
                    threshold = np.percentile(scenario.objective_history[i], q)
                    if r > threshold:
                        # Compress to near-threshold with small noise
                        r = threshold + scenario.plateau_noise * np.random.random()
                    compressed.append(r)
                return type(true_reward)(compressed)

            threshold = np.percentile(scenario.reward_history, q)
            if true_reward > threshold:
                return threshold + scenario.plateau_noise * np.random.random()
            return true_reward

        return plateau_eval

    def reset(self):
        """Clear all reward histories before a new run."""
        self.reward_history = []
        self.objective_history = {}


# ============================================================================
# COMPREHENSIVE FTRL ABLATION STUDY
# ============================================================================

class FTRLComprehensiveStudy:
    """
    Comprehensive study of FTRL's effect across multiple dimensions.

    Shared configuration lives in class attributes:

    * ``NUM_PER_EXP`` - passed as ``num_runs`` to
      ``MOCOEvaluator.evaluate_algorithm`` in ``run_single_experiment``.
    * ``DEFAULT_BASE_PARAMS`` - algorithm keyword arguments; copied and
      size-adjusted by ``_get_base_params``.
    * ``PROBLEM_CONFIGS`` - per problem type: the problem class name, the
      'small'/'medium'/'large' constructor kwargs, a hypervolume
      reference-point function ('ref') and the kwarg holding the instance
      size ('size_key'); resolved by ``_get_problem_config``.
    """
    # Runs per MOCOEvaluator call (used as num_runs in run_single_experiment).
    NUM_PER_EXP = 1
    # Default algorithm parameters (shared across all studies).
    # _get_base_params copies this dict and overrides decomposition_size and
    # overlap from the instance size.
    DEFAULT_BASE_PARAMS = {
        'learning_rate': 0.5,
        'initial_temperature': 1.0,
        'temp_decay': 0.98,
        'hybrid_ratio': 0.5,
        'adaptive_hybrid': True,
        'max_iterations': 150,
        'decomposition_size': 25,  # Overridden by _get_base_params: max(15, n // 5)
        'overlap': 11,              # Overridden by _get_base_params: max(3, n // 15)
        'use_lagrangian': True,
        'use_ftrl': True,
        'dual_step_size': 1.0,
        'n_weight_vectors': 25,
        'nb_rounds': 10, # tsp20: 5
        'patience': 50 # tsp20: 30
    }
    
    # Problem configurations (centralized).
    # 'ref' maps the instance size (the value of the 'size_key' kwarg of the
    # chosen preset) to the hypervolume reference point. The lambdas are stored
    # in a dict, so they are called as plain functions (config['ref'](n)),
    # not as bound methods.
    PROBLEM_CONFIGS = {
        'BiKP': {
            'class_name': 'MultiObjectiveKnapsack',
            'small': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5},
            'medium': {'n_items': 100, 'n_objectives': 2, 'capacity': 25.0},
            'large': {'n_items': 200, 'n_objectives': 2, 'capacity': 50.0},
            'ref': lambda n: (5, 5) if n <= 50 else (20, 20) if n <= 100 else (30, 30),
            'size_key': 'n_items'
        },
        'BiTSP': {
            'class_name': 'BiObjectiveTSP',
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
            'ref': lambda n: (20, 20) if n <= 20 else (35, 35) if n <= 50 else (65, 65),
            'size_key': 'n_cities'
        },
        'TriTSP': {
            'class_name': 'TriObjectiveTSP',
            'small': {'n_cities': 20},
            'medium': {'n_cities': 50},
            'large': {'n_cities': 100},
            'ref': lambda n: (20, 20, 20) if n <= 20 else (35, 35, 35) if n <= 50 else (65, 65, 65),
            'size_key': 'n_cities'
        }
    }
    
    def __init__(
        self,
        output_dir: str = 'ftrl_study_results',
        num_seeds: int = 30,
        verbose: bool = True
    ):
        """
        Configure the study and eagerly create ``output_dir``.

        Args:
            output_dir: Directory for study outputs (created if missing).
            num_seeds: Number of seeds run per configuration.
            verbose: When True, ``log`` prints progress messages.
        """
        self.output_dir, self.num_seeds, self.verbose = output_dir, num_seeds, verbose

        os.makedirs(output_dir, exist_ok=True)

        # Two-level auto-vivifying stores: key -> key -> list of entries.
        self.results = defaultdict(lambda: defaultdict(list))
        self.tracking_data = defaultdict(lambda: defaultdict(list))
    
    def _get_base_params(self, problem_size: int) -> Dict[str, Any]:
        """
        Return a fresh copy of ``DEFAULT_BASE_PARAMS`` scaled to *problem_size*.

        ``decomposition_size`` becomes ``problem_size // 5`` (at least 15) and
        ``overlap`` becomes ``problem_size // 15`` (at least 3); every other
        entry keeps its class default.
        """
        return {
            **self.DEFAULT_BASE_PARAMS,
            'decomposition_size': max(15, problem_size // 5),
            'overlap': max(3, problem_size // 15),
        }
    
    def _get_problem_config(self, problem_type: str, problem_size: str = 'medium'):
        """
        Resolve a problem type / size preset from ``PROBLEM_CONFIGS``.

        Args:
            problem_type: Key of ``PROBLEM_CONFIGS`` ('BiKP', 'BiTSP', 'TriTSP').
                Unknown types fall back to 'BiKP'; the fallback is logged.
            problem_size: Size preset of that type ('small', 'medium', 'large').

        Returns:
            tuple: (problem_class, problem_params, ref_point, actual_size).
            ``problem_params`` is a copy, so callers may modify it without
            changing the shared class-level preset.

        Raises:
            KeyError: if ``problem_size`` is not a size preset of the type.
        """
        from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP

        class_map = {
            'MultiObjectiveKnapsack': MultiObjectiveKnapsack,
            'BiObjectiveTSP': BiObjectiveTSP,
            'TriObjectiveTSP': TriObjectiveTSP
        }

        config = self.PROBLEM_CONFIGS.get(problem_type)
        if config is None:
            self.log(f"  Unknown problem type {problem_type!r}; falling back to 'BiKP'")
            config = self.PROBLEM_CONFIGS['BiKP']
        problem_class = class_map[config['class_name']]

        # Size presets are the dict-valued entries ('class_name', 'ref' and
        # 'size_key' are metadata, not presets).
        preset = config.get(problem_size)
        if not isinstance(preset, dict):
            sizes = [key for key, value in config.items() if isinstance(value, dict)]
            raise KeyError(
                f"Unknown problem size {problem_size!r} for {problem_type!r}; "
                f"expected one of {sizes}"
            )
        problem_params = dict(preset)

        actual_size = problem_params.get(config['size_key'], 50)
        ref_point = config['ref'](actual_size)

        return problem_class, problem_params, ref_point, actual_size
    
    def log(self, message: str):
        """Print *message*, but only when the study was created with ``verbose=True``."""
        if not self.verbose:
            return
        print(message)
    
    def _compute_hypervolume_simple(
        self,
        objectives: List,
        ref_point: Tuple,
        n_samples: int = 10000,
        rng: Optional[np.random.Generator] = None,
    ) -> float:
        """
        Hypervolume of a point set w.r.t. a reference point (minimization).

        Only points strictly better than ``ref_point`` in every objective
        contribute; dominated points are handled implicitly.

        * 1 objective: exact, ``ref - min``.
        * 2 objectives: exact sweep over the non-dominated front.
        * 3+ objectives: Monte Carlo estimate with ``n_samples`` uniform
          samples in the box ``[min(points), ref_point]``.

        Args:
            objectives: Sequence of objective vectors (list of lists or 2-D
                array). Only the first ``len(ref_point)`` coordinates are used.
            ref_point: Reference point, one coordinate per objective.
            n_samples: Monte Carlo sample count (3+ objectives only).
            rng: Optional ``numpy.random.Generator`` for the Monte Carlo
                estimate. Defaults to a fixed-seed generator, so estimates are
                reproducible and NumPy's global RNG state is left untouched.

        Returns:
            float: Hypervolume (0.0 when no point lies inside the reference box).

        Raises:
            ValueError: if ``n_samples`` is not positive and a Monte Carlo
                estimate is required.
        """
        points = np.asarray(objectives, dtype=float)
        if points.size == 0:
            return 0.0
        points = np.atleast_2d(points)
        ref = np.asarray(ref_point, dtype=float).ravel()
        n_obj = ref.shape[0]
        points = points[:, :n_obj]

        # Keep only points strictly inside the reference box.
        inside = points[np.all(points < ref, axis=1)]
        if inside.shape[0] == 0:
            return 0.0

        if n_obj == 1:
            return float(ref[0] - inside[:, 0].min())

        if n_obj == 2:
            # Sort by f1 ascending (ties: f2 ascending) and keep points that
            # strictly improve f2: the non-dominated front, f1 increasing and
            # f2 decreasing.
            order = np.lexsort((inside[:, 1], inside[:, 0]))
            front = []
            best_f2 = np.inf
            for f1, f2 in inside[order]:
                if f2 < best_f2:
                    front.append((f1, f2))
                    best_f2 = f2

            # Sum disjoint vertical slabs [f1_i, f1_{i+1}) x [f2_i, ref_2),
            # where the last slab extends to ref_1.
            hv = 0.0
            for i, (f1, f2) in enumerate(front):
                next_f1 = front[i + 1][0] if i + 1 < len(front) else ref[0]
                hv += (next_f1 - f1) * (ref[1] - f2)
            return float(max(0.0, hv))

        # Higher dimensions: Monte Carlo approximation.
        if n_samples <= 0:
            raise ValueError("n_samples must be positive")
        if rng is None:
            rng = np.random.default_rng(0)

        lower = inside.min(axis=0)
        samples = rng.uniform(low=lower, high=ref, size=(n_samples, n_obj))

        # A sample is dominated if some point is <= it in every objective.
        # Process in chunks to bound the (samples x points x objectives) buffer.
        dominated = 0
        chunk = 1024
        for start in range(0, n_samples, chunk):
            batch = samples[start:start + chunk]
            covered = np.all(inside[None, :, :] <= batch[:, None, :], axis=2).any(axis=1)
            dominated += int(covered.sum())

        box_volume = float(np.prod(ref - lower))
        return max(0.0, dominated / n_samples * box_volume)
    
    def run_single_experiment(
        self,
        wrapper_class,
        problem_class,
        problem_params: Dict,
        algorithm_params: Dict,
        ref_point: Tuple,
        scenario: StressTestScenario = None,
        seed: int = 0,
        use_tracking: bool = True
    ) -> Dict[str, Any]:
        """
        Run one seeded experiment through ``MOCOEvaluator.evaluate_algorithm``.

        NumPy's global RNG and torch are seeded with *seed* first. When
        *scenario* is given, it is reset and a subclass of *wrapper_class* is
        built. That subclass passes the output of ``_create_evaluator``
        through ``scenario.modify_evaluation``.

        The per-iteration fields (``iteration_rewards``, ``best_rewards``,
        ``regret_curve``, ``stability_metrics``) are always empty here;
        ``run_direct_experiment`` is the variant that fills them.
        ``use_tracking`` is accepted for signature parity but not used.

        Returns:
            dict with hypervolume, runtime, num_solutions, objectives, the
            per-iteration fields, ``success`` and, on failure, ``error``.
            Exceptions are logged with a traceback and never propagated.
        """
        import traceback as tb

        np.random.seed(seed)
        torch.manual_seed(seed)

        def per_iteration_placeholders():
            # Fresh empty containers for the fields this path cannot fill.
            return {
                'iteration_rewards': [],
                'best_rewards': [],
                'regret_curve': [],
                'stability_metrics': {},
            }

        try:
            from MOCO.evaluation import MOCOEvaluator

            # NOTE(review): this directory is created but not passed to the
            # evaluator here; confirm whether MOCOEvaluator relies on it.
            results_dir = 'results_MOCO_April'
            os.makedirs(results_dir, exist_ok=True)

            evaluator = MOCOEvaluator(reference_point=ref_point)

            algorithm_class = wrapper_class
            if scenario is not None:
                scenario.reset()
                stress = scenario
                base_class = wrapper_class

                class StressTestWrapper(base_class):
                    """Wrapper that applies stress test to evaluation"""

                    def _create_evaluator(self, weights):
                        # Wrap the parent's per-weight evaluator with the stress scenario.
                        return stress.modify_evaluation(super()._create_evaluator(weights))

                algorithm_class = StressTestWrapper

            t0 = time.time()
            result = evaluator.evaluate_algorithm(
                algorithm_class=algorithm_class,
                problem_class=problem_class,
                algorithm_name="FTRL_Study",
                parameters=algorithm_params,
                problem_params=problem_params,
                num_runs=self.NUM_PER_EXP
            )
            runtime = time.time() - t0

            record = {
                'hypervolume': getattr(result, 'hypervolume', 0.0),
                'runtime': runtime,
                'num_solutions': getattr(result, 'num_nondominated', 0),
                'objectives': getattr(result, 'objectives', []),
            }
            record.update(per_iteration_placeholders())
            record['success'] = True
            return record

        except Exception as e:
            self.log(f"    Error in seed {seed}: {e}")
            tb.print_exc()
            failure = {
                'hypervolume': 0.0,
                'runtime': 0.0,
                'num_solutions': 0,
                'objectives': [],
            }
            failure.update(per_iteration_placeholders())
            failure.update(success=False, error=str(e))
            return failure
    
    def run_direct_experiment(
        self,
        wrapper_class,
        problem_class,
        problem_params: Dict,
        algorithm_params: Dict,
        ref_point: Tuple,
        scenario: StressTestScenario = None,
        seed: int = 0,
        use_tracking: bool = True
    ) -> Dict[str, Any]:
        """
        Run one seeded experiment by instantiating the wrapper directly.

        Unlike ``run_single_experiment`` this bypasses ``MOCOEvaluator``, so the
        wrapper's tracker can be read afterwards (via ``extract_tracking_data``)
        for regret and stability analysis.

        Args:
            wrapper_class: Algorithm wrapper class. With ``use_tracking=True`` it
                only selects the family: classes whose module or name mention
                "thompson" map to the Thompson tracking wrapper, all others to
                the UCB tracking wrapper.
            problem_class: Problem constructor, called with ``**problem_params``.
            problem_params: Problem keyword arguments.
            algorithm_params: Keyword arguments for the wrapper constructor.
            ref_point: Hypervolume reference point.
            scenario: Optional stress test. It is reset and, if the problem has
                an ``evaluate`` attribute, wraps it.
            seed: Seed for NumPy's global RNG and torch.
            use_tracking: Use the tracking wrappers instead of *wrapper_class*.

        Returns:
            dict with the same keys as ``run_single_experiment``, with the
            per-iteration fields filled from the tracker when available.
            Exceptions are logged with a traceback and never propagated.
        """
        import traceback as tb

        np.random.seed(seed)
        torch.manual_seed(seed)

        try:
            problem = problem_class(**problem_params)

            # Apply stress test scenario if provided
            if scenario is not None:
                scenario.reset()
                if hasattr(problem, 'evaluate'):
                    original_eval = problem.evaluate
                    problem.evaluate = scenario.modify_evaluation(original_eval)

            if use_tracking:
                ucb_tracking_cls, thompson_tracking_cls = get_tracking_wrappers()
                # Both variant modules export a class literally named
                # ``CachedAdvancedBiKPWrapper``, so ``__name__`` alone cannot tell
                # them apart; include the defining module in the check.
                module_name = getattr(wrapper_class, '__module__', '') or ''
                identity = f"{module_name}.{wrapper_class.__name__}".lower()
                actual_wrapper_class = (
                    thompson_tracking_cls if 'thompson' in identity else ucb_tracking_cls
                )
            else:
                actual_wrapper_class = wrapper_class

            start_time = time.time()
            wrapper = actual_wrapper_class(problem, **algorithm_params)
            result_data = wrapper.run()
            runtime = time.time() - start_time

            tracking = extract_tracking_data(wrapper)

            hv = 0.0
            num_solutions = 0
            objectives = []

            if isinstance(result_data, list) and len(result_data) > 0:
                for item in result_data:
                    if not isinstance(item, (list, tuple)) or len(item) < 2:
                        continue
                    # Either a (solution, objectives) pair, whose objectives may
                    # be a list, tuple or 1-D array, or a flat objective vector.
                    second = item[1]
                    is_vector = isinstance(second, (list, tuple)) or (
                        isinstance(second, np.ndarray) and second.ndim == 1
                    )
                    obj = second if is_vector else item
                    objectives.append(list(obj))

                num_solutions = len(result_data)

                if objectives:
                    hv = self._compute_hypervolume_simple(objectives, ref_point)

            return {
                'hypervolume': hv,
                'runtime': runtime,
                'num_solutions': num_solutions,
                'objectives': objectives,
                'iteration_rewards': tracking.get('iteration_rewards', []),
                'best_rewards': tracking.get('best_rewards', []),
                'regret_curve': tracking.get('regret_curve', []),
                'stability_metrics': tracking.get('stability_metrics', {}),
                'success': True
            }

        except Exception as e:
            self.log(f"    Error in seed {seed}: {e}")
            tb.print_exc()
            return {
                'hypervolume': 0.0,
                'runtime': 0.0,
                'num_solutions': 0,
                'objectives': [],
                'iteration_rewards': [],
                'best_rewards': [],
                'regret_curve': [],
                'stability_metrics': {},
                'success': False,
                'error': str(e)
            }
    
    def run_ftrl_rate_study(
        self,
        problem_type: str = 'BiKP',
        problem_size: str = 'medium',
        ftrl_rates: Optional[List[float]] = None
    ) -> Dict[str, Any]:
        """
        Study 1: Effect of FTRL usage rate (0% to 100%)

        This isolates FTRL's effect by varying how often it's used. For every
        rate, both the UCB and Thompson wrappers are run for ``self.num_seeds``
        seeds via ``run_single_experiment``.

        Args:
            problem_type: Problem identifier passed to ``_get_problem_config``.
            problem_size: Size label passed to ``_get_problem_config``.
            ftrl_rates: FTRL usage rates to test. ``None`` (the default) means
                ``[0.0, 0.3, 0.5, 0.7, 1.0]``.

        Returns:
            Dict keyed by ``"{alg}_rate{rate}"`` mapping each metric name
            ('hypervolume', 'runtime', 'num_solutions', 'regret_curves',
            'stability') to a list with one entry per seed.
        """
        # Sentinel instead of a mutable list default shared across calls.
        if ftrl_rates is None:
            ftrl_rates = [0.0, 0.3, 0.5, 0.7, 1.0]

        self.log("\n" + "="*80)
        self.log("STUDY 1: FTRL USAGE RATE ANALYSIS")
        self.log("="*80)

        # UCBWrapper / ThompsonWrapper come from the module-level imports.

        # Get problem configuration using centralized method
        problem_class, problem_params, ref_point, actual_size = self._get_problem_config(
            problem_type, problem_size
        )

        self.log(f"Problem: {problem_type} ({problem_size}, n={actual_size})")
        self.log(f"FTRL rates to test: {ftrl_rates}")
        self.log(f"Seeds per configuration: {self.num_seeds}")

        results = defaultdict(lambda: defaultdict(list))

        # Get base algorithm parameters adjusted for problem size
        base_params = self._get_base_params(actual_size)

        for ftrl_rate in ftrl_rates:
            self.log(f"\n--- FTRL Rate: {ftrl_rate*100:.0f}% ---")

            for alg_name, wrapper_class in [('UCB', UCBWrapper), ('Thompson', ThompsonWrapper)]:
                self.log(f"  Testing {alg_name}...")

                params = base_params.copy()
                params['ftrl_rate'] = ftrl_rate
                # A zero rate means FTRL is disabled entirely.
                params['use_ftrl'] = ftrl_rate > 0

                # Loop-invariant: one key per (algorithm, rate) pair.
                config_key = f"{alg_name}_rate{ftrl_rate}"

                for seed in range(self.num_seeds):
                    result = self.run_single_experiment(
                        wrapper_class=wrapper_class,
                        problem_class=problem_class,
                        problem_params=problem_params,
                        algorithm_params=params,
                        ref_point=ref_point,
                        seed=seed
                    )

                    results[config_key]['hypervolume'].append(result['hypervolume'])
                    results[config_key]['runtime'].append(result['runtime'])
                    results[config_key]['num_solutions'].append(result['num_solutions'])
                    results[config_key]['regret_curves'].append(result['regret_curve'])
                    results[config_key]['stability'].append(result['stability_metrics'])

                    if (seed + 1) % 10 == 0:
                        self.log(f"    Completed {seed + 1}/{self.num_seeds} seeds")

        # Save and plot results
        self._save_ftrl_rate_results(results, problem_type, problem_size, ftrl_rates)
        self._plot_ftrl_rate_study(results, problem_type, problem_size, ftrl_rates, actual_size)

        return dict(results)
    
    def run_variance_analysis(
        self,
        problem_type: str = 'BiKP',
        problem_size: str = 'medium'
    ) -> Dict[str, Any]:
        """
        Study 2: Variance Reduction Analysis

        Compare the seed-to-seed spread of hypervolume for UCB and Thompson,
        each run with FTRL always on (rate 1.0) and always off (rate 0.0).

        Returns:
            ``{'results': ..., 'variance_stats': ...}``. ``results`` maps a
            config name to metric name ('hypervolume', 'runtime', 'stability')
            to a per-seed list; ``variance_stats`` is the summary produced by
            ``_compute_variance_statistics``.
        """
        divider = "=" * 80
        self.log("\n" + divider)
        self.log("STUDY 2: VARIANCE REDUCTION ANALYSIS")
        self.log(divider)

        from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
            CachedAdvancedBiKPWrapper as UCBWrapper
        )
        from our_method_dl_Thompson_variant import (
            CachedAdvancedBiKPWrapper as ThompsonWrapper
        )

        problem_class, problem_params, ref_point, actual_size = self._get_problem_config(
            problem_type, problem_size
        )

        self.log(f"Problem: {problem_type} ({problem_size}, n={actual_size})")
        self.log(f"Running {self.num_seeds} seeds per configuration")

        # (config name, wrapper class, use_ftrl flag, FTRL usage rate)
        experiment_grid = [
            ('UCB_with_FTRL', UCBWrapper, True, 1.0),
            ('UCB_without_FTRL', UCBWrapper, False, 0.0),
            ('Thompson_with_FTRL', ThompsonWrapper, True, 1.0),
            ('Thompson_without_FTRL', ThompsonWrapper, False, 0.0)
        ]
        results = {entry[0]: defaultdict(list) for entry in experiment_grid}

        base_params = self._get_base_params(actual_size)

        # (key stored in results, key read from the experiment output)
        recorded_fields = (
            ('hypervolume', 'hypervolume'),
            ('runtime', 'runtime'),
            ('stability', 'stability_metrics'),
        )

        for name, wrapper_class, use_ftrl, ftrl_rate in experiment_grid:
            self.log(f"\n--- {name} ---")

            params = base_params.copy()
            params.update(use_ftrl=use_ftrl, ftrl_rate=ftrl_rate)

            bucket = results[name]
            for seed in range(self.num_seeds):
                run = self.run_single_experiment(
                    wrapper_class=wrapper_class,
                    problem_class=problem_class,
                    problem_params=problem_params,
                    algorithm_params=params,
                    ref_point=ref_point,
                    seed=seed
                )
                for dest_key, src_key in recorded_fields:
                    bucket[dest_key].append(run[src_key])

                done = seed + 1
                if done % 10 == 0:
                    self.log(f"  Completed {done}/{self.num_seeds} seeds")

        variance_stats = self._compute_variance_statistics(results)

        self._save_variance_results(results, variance_stats, problem_type, problem_size)
        self._plot_variance_analysis(results, variance_stats, problem_type, actual_size)

        return {'results': dict(results), 'variance_stats': variance_stats}
    
    def run_stress_tests(
        self,
        problem_type: str = 'BiKP',
        problem_size: str = 'medium'
    ) -> Dict[str, Any]:
        """
        Study 3: Stress Test Suite

        Run UCB and Thompson, each with FTRL fully on and fully off, under a
        fixed set of perturbation scenarios (baseline, high variance,
        non-stationary, adversarial, near-optimal plateau).

        Returns:
            Nested mapping scenario name -> config key
            (``"{alg}_with_FTRL"`` / ``"{alg}_without_FTRL"``) -> metric name
            ('hypervolume', 'runtime') -> per-seed list.
        """
        banner = "=" * 80
        self.log("\n" + banner)
        self.log("STUDY 3: STRESS TEST SUITE")
        self.log(banner)

        from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
            CachedAdvancedBiKPWrapper as UCBWrapper
        )
        from our_method_dl_Thompson_variant import (
            CachedAdvancedBiKPWrapper as ThompsonWrapper
        )

        problem_class, problem_params, ref_point, actual_size = self._get_problem_config(
            problem_type, problem_size
        )

        # (label, scenario) pairs; a None scenario is the unperturbed baseline.
        suite = [
            ('Baseline', None),
            ('HighVar_0.3', HighVarianceScenario(noise_std=0.3)),
            ('HighVar_0.7', HighVarianceScenario(noise_std=0.7)),
            ('NonStat_0.2', NonStationaryScenario(change_points=[40, 80, 120], change_magnitude=0.2)),
            ('NonStat_0.5', NonStationaryScenario(change_points=[40, 80, 120], change_magnitude=0.5)),
            ('Adversarial', AdversarialScenario(adversarial_strength=0.15)),
            ('Plateau', NearOptimalPlateauScenario(plateau_threshold=0.85)),
        ]
        scenario_names = [label for label, _ in suite]

        self.log(f"Problem: {problem_type} ({problem_size}, n={actual_size})")
        self.log(f"Scenarios: {scenario_names}")
        self.log(f"Seeds per configuration: {self.num_seeds}")

        results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

        base_params = self._get_base_params(actual_size)

        algorithms = (('UCB', UCBWrapper), ('Thompson', ThompsonWrapper))
        ftrl_variants = ((True, 'with_FTRL'), (False, 'without_FTRL'))

        for scenario_name, scenario in suite:
            self.log(f"\n=== Scenario: {scenario_name} ===")

            for alg_name, wrapper_class in algorithms:
                for use_ftrl, ftrl_label in ftrl_variants:
                    config_key = f"{alg_name}_{ftrl_label}"
                    self.log(f"  Testing {config_key}...")

                    params = base_params.copy()
                    params['use_ftrl'] = use_ftrl
                    params['ftrl_rate'] = 1.0 if use_ftrl else 0.0

                    for seed in range(self.num_seeds):
                        outcome = self.run_single_experiment(
                            wrapper_class=wrapper_class,
                            problem_class=problem_class,
                            problem_params=problem_params,
                            algorithm_params=params,
                            ref_point=ref_point,
                            scenario=scenario,
                            seed=seed
                        )

                        slot = results[scenario_name][config_key]
                        slot['hypervolume'].append(outcome['hypervolume'])
                        slot['runtime'].append(outcome['runtime'])

        self._save_stress_test_results(results, problem_type, problem_size, scenario_names)
        self._plot_stress_tests(results, problem_type, scenario_names, actual_size)

        return dict(results)
    
    def run_regret_analysis(
        self,
        problem_type: str = 'BiKP',
        problem_size: str = 'medium',
        max_iterations: int = 200
    ) -> Dict[str, Any]:
        """
        Study 4: Regret Curve Analysis

        Track cumulative regret over iterations to understand learning dynamics.
        Uses ``run_direct_experiment`` with ``use_tracking=True`` to capture
        per-iteration data for UCB/Thompson with FTRL fully on and fully off.

        Args:
            problem_type: Problem identifier passed to ``_get_problem_config``.
            problem_size: Size label passed to ``_get_problem_config``.
            max_iterations: Iteration budget written into the algorithm params.

        Returns:
            Dict keyed by config name mapping each metric name
            ('hypervolume', 'regret_curves', 'best_rewards',
            'iteration_rewards', 'stability') to a per-seed list. At most 15
            seeds are run per configuration (further capped by
            ``self.num_seeds``).
        """
        self.log("\n" + "="*80)
        self.log("STUDY 4: REGRET CURVE ANALYSIS")
        self.log("="*80)

        # Get problem configuration using centralized method
        problem_class, problem_params, ref_point, actual_size = self._get_problem_config(
            problem_type, problem_size
        )

        self.log(f"Problem: {problem_type} ({problem_size}, n={actual_size})")
        self.log(f"Max iterations: {max_iterations}")
        self.log("Using direct experiment execution for iteration tracking")

        results = defaultdict(lambda: defaultdict(list))

        base_params = self._get_base_params(actual_size)
        base_params['max_iterations'] = max_iterations
        base_params['patience'] = 50  # Higher patience for regret analysis

        configs = [
            ('UCB_with_FTRL', UCBWrapper, True),
            ('UCB_without_FTRL', UCBWrapper, False),
            ('Thompson_with_FTRL', ThompsonWrapper, True),
            ('Thompson_without_FTRL', ThompsonWrapper, False)
        ]

        # Use fewer seeds for regret analysis (it's more expensive)
        regret_seeds = min(15, self.num_seeds)

        for config_name, wrapper_class, use_ftrl in configs:
            self.log(f"\n--- {config_name} ---")

            params = base_params.copy()
            params['use_ftrl'] = use_ftrl
            params['ftrl_rate'] = 1.0 if use_ftrl else 0.0

            for seed in range(regret_seeds):
                # Use run_direct_experiment to capture iteration data
                result = self.run_direct_experiment(
                    wrapper_class=wrapper_class,
                    problem_class=problem_class,
                    problem_params=problem_params,
                    algorithm_params=params,
                    ref_point=ref_point,
                    seed=seed,
                    use_tracking=True
                )

                results[config_name]['hypervolume'].append(result['hypervolume'])
                results[config_name]['regret_curves'].append(result['regret_curve'])
                results[config_name]['best_rewards'].append(result['best_rewards'])
                results[config_name]['iteration_rewards'].append(result['iteration_rewards'])
                results[config_name]['stability'].append(result['stability_metrics'])

                if (seed + 1) % 5 == 0:
                    n_tracked = len(result['iteration_rewards'])
                    self.log(f"  Completed {seed + 1}/{regret_seeds} seeds (tracked {n_tracked} iterations)")

        # Persist BEFORE plotting (as the other studies do): these runs are the
        # most expensive in the suite, and a plotting error must not lose them.
        self._save_regret_results(dict(results), problem_type, problem_size)
        self._plot_regret_analysis(results, problem_type, actual_size)

        return dict(results)
    
    # =========================================================================
    # HELPER METHODS
    # =========================================================================
    
    def _compute_variance_statistics(self, results: Dict) -> Dict[str, Any]:
        """Compute comprehensive variance statistics"""
        stats = {}
        
        for config_name, data in results.items():
            hvs = np.array(data['hypervolume'])
            
            # Skip if no valid data
            if len(hvs) == 0 or np.all(hvs == 0):
                stats[config_name] = {
                    'mean': 0.0,
                    'std': 0.0,
                    'var': 0.0,
                    'cv': float('inf'),
                    'min': 0.0,
                    'max': 0.0,
                    'q25': 0.0,
                    'q75': 0.0,
                    'iqr': 0.0,
                    'range': 0.0
                }
                continue
            
            mean_val = float(np.mean(hvs))
            
            stats[config_name] = {
                'mean': mean_val,
                'std': float(np.std(hvs)),
                'var': float(np.var(hvs)),
                'cv': float(np.std(hvs) / mean_val) if mean_val > 1e-10 else float('inf'),
                'min': float(np.min(hvs)),
                'max': float(np.max(hvs)),
                'q25': float(np.percentile(hvs, 25)),
                'q75': float(np.percentile(hvs, 75)),
                'iqr': float(np.percentile(hvs, 75) - np.percentile(hvs, 25)),
                'range': float(np.max(hvs) - np.min(hvs))
            }
        
        # Compute relative variance reduction
        for alg in ['UCB', 'Thompson']:
            with_ftrl = stats.get(f'{alg}_with_FTRL', {})
            without_ftrl = stats.get(f'{alg}_without_FTRL', {})
            
            if with_ftrl and without_ftrl:
                # Avoid division by zero
                if without_ftrl.get('var', 0) > 1e-10:
                    var_reduction = (without_ftrl['var'] - with_ftrl['var']) / without_ftrl['var'] * 100
                else:
                    # If variance is essentially zero, report 0% reduction
                    var_reduction = 0.0
                stats[f'{alg}_variance_reduction_%'] = var_reduction
        
        return stats
    
    def _save_ftrl_rate_results(self, results, problem_type, problem_size, ftrl_rates):
        """
        Save FTRL rate study results to a timestamped YAML file.

        Values are coerced to plain Python ``float``/``int`` before dumping so
        numpy scalars (or tuples) do not end up as Python-specific YAML tags
        that ``yaml.safe_load`` cannot read back.

        Args:
            results: Mapping ``"{alg}_rate{rate}"`` -> metric name -> per-seed list.
            problem_type: Problem identifier, used in the payload and filename.
            problem_size: Size label, used in the payload and filename.
            ftrl_rates: The FTRL rates that were tested.
        """
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        save_data = {
            'study': 'ftrl_rate',
            'problem_type': problem_type,
            'problem_size': problem_size,
            'ftrl_rates': [float(rate) for rate in ftrl_rates],
            'num_seeds': self.num_seeds,
            'results': {}
        }

        for config_key, data in results.items():
            hvs = [float(hv) for hv in data.get('hypervolume', [])]
            runtimes = [float(t) for t in data.get('runtime', [])]
            num_solutions = [int(n) for n in data.get('num_solutions', [])]

            # Summary stats default to 0.0 when no seeds were recorded.
            save_data['results'][config_key] = {
                'hypervolume': hvs,
                'runtime': runtimes,
                'num_solutions': num_solutions,
                'mean_hv': float(np.mean(hvs)) if hvs else 0.0,
                'std_hv': float(np.std(hvs)) if hvs else 0.0,
                'min_hv': float(np.min(hvs)) if hvs else 0.0,
                'max_hv': float(np.max(hvs)) if hvs else 0.0
            }

        filename = os.path.join(
            self.output_dir,
            f'ftrl_rate_study_{problem_type}_{problem_size}_{timestamp}.yaml'
        )

        with open(filename, 'w') as f:
            yaml.dump(save_data, f, default_flow_style=False)

        self.log(f"\nResults saved to: {filename}")
    
    def _save_variance_results(self, results, variance_stats, problem_type, problem_size):
        """
        Save variance analysis results to a timestamped YAML file.

        Args:
            results: Mapping config name -> metric name -> per-seed list.
            variance_stats: Summary statistics for the configurations.
            problem_type: Problem identifier, used in the payload and filename.
            problem_size: Size label, used in the payload and filename.
        """
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        save_data = {
            'study': 'variance_analysis',
            'problem_type': problem_type,
            # Recorded in the payload too (not just the filename) so the file
            # is self-describing, like the ftrl-rate and regret result files.
            'problem_size': problem_size,
            'num_seeds': self.num_seeds,
            'variance_statistics': variance_stats,
            # Convert inner defaultdicts to plain dicts for YAML output.
            'raw_results': {k: dict(v) for k, v in results.items()}
        }

        filename = os.path.join(
            self.output_dir,
            f'variance_analysis_{problem_type}_{problem_size}_{timestamp}.yaml'
        )

        with open(filename, 'w') as f:
            yaml.dump(save_data, f, default_flow_style=False)

        self.log(f"\nVariance results saved to: {filename}")
    
    def _save_stress_test_results(self, results, problem_type, problem_size, scenario_names):
        """
        Save stress test results to a timestamped YAML file.

        Args:
            results: Mapping scenario name -> config key -> metric name -> list.
            problem_type: Problem identifier, used in the payload and filename.
            problem_size: Size label, used in the payload and filename.
            scenario_names: Scenario labels to save, in order. A scenario with
                no recorded configurations is saved as an empty mapping.
        """
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        save_data = {
            'study': 'stress_tests',
            'problem_type': problem_type,
            # Recorded in the payload too, matching the other study files.
            'problem_size': problem_size,
            'scenarios': list(scenario_names),
            'num_seeds': self.num_seeds,
            'results': {}
        }

        for scenario_name in scenario_names:
            scenario_out = {}
            for config_key, data in results.get(scenario_name, {}).items():
                # Native floats keep the YAML free of numpy-specific tags.
                hvs = [float(hv) for hv in data.get('hypervolume', [])]
                scenario_out[config_key] = {
                    'hypervolume': hvs,
                    'mean_hv': float(np.mean(hvs)) if hvs else 0.0,
                    'std_hv': float(np.std(hvs)) if hvs else 0.0,
                    'min_hv': float(np.min(hvs)) if hvs else 0.0
                }
            save_data['results'][scenario_name] = scenario_out

        filename = os.path.join(
            self.output_dir,
            f'stress_tests_{problem_type}_{problem_size}_{timestamp}.yaml'
        )

        with open(filename, 'w') as f:
            yaml.dump(save_data, f, default_flow_style=False)

        self.log(f"\nStress test results saved to: {filename}")
    
    def _save_regret_results(self, results: Dict, problem_type: str, problem_size: str):
        """
        Save regret analysis results to a timestamped YAML file.

        All values are recursively converted to native Python types (numpy
        arrays/scalars, tuples and dict subclasses included) so the YAML does
        not contain Python-specific tags.

        Args:
            results: Mapping config name -> metric name ('hypervolume',
                'regret_curves', 'best_rewards', 'iteration_rewards',
                'stability') -> per-seed list.
            problem_type: Problem identifier, used in the payload and filename.
            problem_size: Size label, used in the payload and filename.

        Returns:
            Path of the written YAML file.
        """
        def to_native(value):
            """Recursively convert numpy containers/scalars to plain Python."""
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, dict):
                return {k: to_native(v) for k, v in value.items()}
            if isinstance(value, (list, tuple)):
                return [to_native(v) for v in value]
            return value

        timestamp = time.strftime("%Y%m%d-%H%M%S")

        save_data = {
            'study': 'regret_analysis',
            'problem_type': problem_type,
            'problem_size': problem_size,
            # Configured seed count. The regret study may run fewer seeds than
            # this, so each config entry also records its actual 'num_runs'.
            'num_seeds': self.num_seeds,
            'results': {}
        }

        for config_name, data in results.items():
            hvs = to_native(data.get('hypervolume', []))
            has_hvs = len(hvs) > 0

            save_data['results'][config_name] = {
                'hypervolume': hvs,
                'regret_curves': to_native(data.get('regret_curves', [])),
                'best_rewards': to_native(data.get('best_rewards', [])),
                'iteration_rewards': to_native(data.get('iteration_rewards', [])),
                'stability': to_native(data.get('stability', [])),
                'num_runs': len(hvs),
                'mean_hv': float(np.mean(hvs)) if has_hvs else 0.0,
                'std_hv': float(np.std(hvs)) if has_hvs else 0.0,
                'min_hv': float(np.min(hvs)) if has_hvs else 0.0,
                'max_hv': float(np.max(hvs)) if has_hvs else 0.0
            }

        filename = os.path.join(
            self.output_dir,
            f'regret_analysis_{problem_type}_{problem_size}_{timestamp}.yaml'
        )

        with open(filename, 'w') as f:
            yaml.dump(save_data, f, default_flow_style=False)

        self.log(f"Regret results saved to: {filename}")
        return filename

    # =========================================================================
    # PLOTTING METHODS
    # =========================================================================
    
    def _plot_ftrl_rate_study(self, results, problem_type, problem_size, ftrl_rates, actual_size):
        """
        Plot FTRL rate study results as a 2x2 figure saved to PNG and PDF.

        Panels: mean +/- std hypervolume vs rate, hypervolume box plots per
        rate, std vs rate, and minimum (worst-case) hypervolume vs rate.

        Args:
            results: Mapping ``"{alg}_rate{rate}"`` -> metric name -> per-seed list.
            problem_type: Problem identifier, used in the title and filename.
            problem_size: Size label, included in the filename so plots line up
                with the YAML file named the same way.
            ftrl_rates: Rates that were tested, in plotting order.
            actual_size: Instance size ``n`` shown in the title.
        """
        algorithms = ('UCB', 'Thompson')

        # Long-format rows per algorithm: one row per (rate, seed).
        rows_by_alg = {alg: [] for alg in algorithms}
        for rate in ftrl_rates:
            for alg_name in algorithms:
                config_key = f"{alg_name}_rate{rate}"
                if config_key not in results:
                    continue
                for hv in results[config_key]['hypervolume']:
                    rows_by_alg[alg_name].append({
                        'FTRL Rate': f"{rate*100:.0f}%",
                        'FTRL Rate Value': rate,
                        'Hypervolume': hv,
                        'Algorithm': alg_name
                    })

        df_by_alg = {alg: pd.DataFrame(rows) for alg, rows in rows_by_alg.items()}
        df_all = pd.DataFrame(rows_by_alg['UCB'] + rows_by_alg['Thompson'])
        # Only draw legends when something was plotted (avoids empty-legend warnings).
        has_data = len(df_all) > 0

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Plot 1: Mean HV (with std error bars) vs FTRL rate
        ax1 = axes[0, 0]
        for alg_name, df in df_by_alg.items():
            if len(df) > 0:
                grouped = df.groupby('FTRL Rate Value')['Hypervolume'].agg(['mean', 'std'])
                ax1.errorbar(grouped.index * 100, grouped['mean'], yerr=grouped['std'],
                             marker='o', label=alg_name, capsize=3, linewidth=2)
        ax1.set_xlabel('FTRL Usage Rate (%)')
        ax1.set_ylabel('Hypervolume')
        ax1.set_title('Mean Hypervolume vs FTRL Rate', fontweight='bold')
        if has_data:
            ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: Box plot comparison
        ax2 = axes[0, 1]
        if has_data:
            sns.boxplot(data=df_all, x='FTRL Rate', y='Hypervolume', hue='Algorithm', ax=ax2)
            ax2.set_title('Hypervolume Distribution by FTRL Rate', fontweight='bold')
            ax2.legend(title='Algorithm')

        # Plot 3: Std vs FTRL rate
        ax3 = axes[1, 0]
        for alg_name, df in df_by_alg.items():
            if len(df) > 0:
                grouped = df.groupby('FTRL Rate Value')['Hypervolume'].std()
                ax3.plot(grouped.index * 100, grouped.values, marker='s',
                         label=f'{alg_name} (std)', linewidth=2)
        ax3.set_xlabel('FTRL Usage Rate (%)')
        ax3.set_ylabel('Standard Deviation')
        ax3.set_title('Variance Reduction with FTRL Rate', fontweight='bold')
        if has_data:
            ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Worst-case (min) performance
        ax4 = axes[1, 1]
        for alg_name, df in df_by_alg.items():
            if len(df) > 0:
                grouped = df.groupby('FTRL Rate Value')['Hypervolume'].min()
                ax4.plot(grouped.index * 100, grouped.values, marker='^',
                         label=f'{alg_name} (min)', linewidth=2)
        ax4.set_xlabel('FTRL Usage Rate (%)')
        ax4.set_ylabel('Minimum Hypervolume')
        ax4.set_title('Worst-Case Performance vs FTRL Rate', fontweight='bold')
        if has_data:
            ax4.legend()
        ax4.grid(True, alpha=0.3)

        fig.suptitle(f'FTRL Rate Study: {problem_type} (n={actual_size})',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()

        timestamp = time.strftime("%Y%m%d-%H%M%S")
        filename = os.path.join(
            self.output_dir,
            f'ftrl_rate_study_{problem_type}_{problem_size}_{timestamp}.png'
        )
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        plt.savefig(filename.replace('.png', '.pdf'), bbox_inches='tight')
        plt.close()

        self.log(f"FTRL rate plot saved to: {filename}")
    
    def _plot_variance_analysis(self, results, variance_stats, problem_type, actual_size):
        """
        Plot variance analysis results as a 2x3 figure saved to PNG and PDF.

        Panels: box plot, split violin plot, coefficient of variation, mean
        with min/max range, IQR, and per-algorithm variance reduction.

        Args:
            results: Mapping config name (e.g. 'UCB_with_FTRL') -> metric
                name -> per-seed list; only 'hypervolume' is used.
            variance_stats: Per-config summary dicts (keys 'mean', 'cv', 'min',
                'max', 'iqr', ...) plus float entries whose key contains
                'variance_reduction'.
            problem_type: Problem identifier, used in the title and filename.
            actual_size: Instance size ``n`` shown in the title.
        """
        palette = {'With FTRL': '#0173B2', 'Without FTRL': '#DE8F05'}

        def _labels(config_name):
            """Map a config name to its (algorithm, FTRL status) labels."""
            algorithm = 'UCB' if 'UCB' in config_name else 'Thompson'
            # Check 'without' explicitly: 'with' is a substring of it.
            ftrl = 'Without FTRL' if 'without' in config_name else 'With FTRL'
            return algorithm, ftrl

        # Prepare data
        data = []
        for config_name, runs in results.items():
            algorithm, ftrl = _labels(config_name)
            for hv in runs['hypervolume']:
                data.append({
                    'Algorithm': algorithm,
                    'FTRL': ftrl,
                    'Config': f"{algorithm}\n{ftrl}",
                    'Hypervolume': hv
                })

        df = pd.DataFrame(data)
        if df.empty:
            # seaborn cannot plot a column-less DataFrame; skip instead of crashing.
            self.log("No hypervolume data available; skipping variance analysis plot.")
            return

        # Per-config summary dicts only (skip the float variance_reduction entries).
        config_stats = {name: s for name, s in variance_stats.items() if isinstance(s, dict)}

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Plot 1: Box plot comparison
        ax1 = axes[0, 0]
        sns.boxplot(data=df, x='Algorithm', y='Hypervolume', hue='FTRL', ax=ax1,
                    palette=palette)
        ax1.set_title('Hypervolume Distribution', fontweight='bold')
        ax1.legend(title='FTRL Status')

        # Plot 2: Violin plot for distribution shape
        ax2 = axes[0, 1]
        sns.violinplot(data=df, x='Algorithm', y='Hypervolume', hue='FTRL', ax=ax2,
                       palette=palette, split=True)
        ax2.set_title('Distribution Shape Comparison', fontweight='bold')
        ax2.legend(title='FTRL Status')

        # Plot 3: Coefficient of Variation bar chart
        ax3 = axes[0, 2]
        cv_data = []
        for config_name, cfg_stats in config_stats.items():
            if 'cv' in cfg_stats:
                algorithm, ftrl = _labels(config_name)
                cv = cfg_stats['cv']
                # cv can be inf (mean ~ 0); NaN simply leaves the bar empty
                # instead of producing infinite axis limits.
                cv_data.append({
                    'Algorithm': algorithm,
                    'FTRL': ftrl,
                    'CV': cv if np.isfinite(cv) else np.nan
                })

        df_cv = pd.DataFrame(cv_data)
        if len(df_cv) > 0:
            sns.barplot(data=df_cv, x='Algorithm', y='CV', hue='FTRL', ax=ax3,
                        palette=palette)
            ax3.set_ylabel('Coefficient of Variation')
            ax3.set_title('Relative Variability (CV)', fontweight='bold')
            ax3.legend(title='FTRL Status')

        # Plot 4: Min/Max range comparison
        ax4 = axes[1, 0]
        range_data = []
        for config_name, cfg_stats in config_stats.items():
            if 'min' in cfg_stats:
                algorithm, ftrl = _labels(config_name)
                range_data.append({
                    'Algorithm': algorithm,
                    'FTRL': ftrl,
                    'Min': cfg_stats['min'],
                    'Max': cfg_stats['max'],
                    'Mean': cfg_stats['mean']
                })

        df_range = pd.DataFrame(range_data)
        if len(df_range) > 0:
            width = 0.35
            for i, row in df_range.iterrows():
                # Exact lookup: the old `'With' in label` test also matched
                # 'Without FTRL', colouring every bar as "With FTRL".
                ax4.bar(i, row['Mean'], width, color=palette[row['FTRL']], alpha=0.7,
                        label=f"{row['Algorithm']} {row['FTRL']}" if i < 4 else "")
                ax4.errorbar(i, row['Mean'],
                             yerr=[[row['Mean'] - row['Min']], [row['Max'] - row['Mean']]],
                             fmt='none', color='black', capsize=5)

            ax4.set_xticks(np.arange(len(df_range)))
            ax4.set_xticklabels([f"{r['Algorithm']}\n{r['FTRL']}" for _, r in df_range.iterrows()],
                                rotation=45, ha='right')
            ax4.set_ylabel('Hypervolume')
            ax4.set_title('Mean with Min/Max Range', fontweight='bold')

        # Plot 5: IQR comparison
        ax5 = axes[1, 1]
        iqr_data = []
        for config_name, cfg_stats in config_stats.items():
            if 'iqr' in cfg_stats:
                algorithm, ftrl = _labels(config_name)
                iqr_data.append({
                    'Algorithm': algorithm,
                    'FTRL': ftrl,
                    'IQR': cfg_stats['iqr']
                })

        df_iqr = pd.DataFrame(iqr_data)
        if len(df_iqr) > 0:
            sns.barplot(data=df_iqr, x='Algorithm', y='IQR', hue='FTRL', ax=ax5,
                        palette=palette)
            ax5.set_ylabel('Interquartile Range')
            ax5.set_title('IQR Comparison', fontweight='bold')
            ax5.legend(title='FTRL Status')

        # Plot 6: Variance reduction summary
        ax6 = axes[1, 2]
        reduction_data = []
        for key, value in variance_stats.items():
            if 'variance_reduction' in key:
                reduction_data.append({
                    'Algorithm': key.split('_')[0],
                    'Variance Reduction (%)': value
                })

        df_reduction = pd.DataFrame(reduction_data)
        if len(df_reduction) > 0:
            colors = ['#029E73' if v >= 0 else '#CC78BC'
                      for v in df_reduction['Variance Reduction (%)']]
            bars = ax6.bar(df_reduction['Algorithm'],
                           df_reduction['Variance Reduction (%)'],
                           color=colors)
            ax6.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
            ax6.set_ylabel('Variance Reduction (%)')
            ax6.set_title('FTRL Variance Reduction Effect', fontweight='bold')

            # Value labels: above positive bars, below negative ones.
            for bar, val in zip(bars, df_reduction['Variance Reduction (%)']):
                height = bar.get_height()
                if height >= 0:
                    y_pos, v_align = height + 0.5, 'bottom'
                else:
                    y_pos, v_align = height - 0.5, 'top'
                ax6.text(bar.get_x() + bar.get_width() / 2, y_pos,
                         f'{val:.1f}%', ha='center', va=v_align, fontweight='bold')

        fig.suptitle(f'Variance Analysis: {problem_type} (n={actual_size})',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()

        timestamp = time.strftime("%Y%m%d-%H%M%S")
        filename = os.path.join(self.output_dir, f'variance_analysis_{problem_type}_{timestamp}.png')
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        plt.savefig(filename.replace('.png', '.pdf'), bbox_inches='tight')
        plt.close()

        self.log(f"Variance analysis plot saved to: {filename}")
    
    def _plot_stress_tests(self, results, problem_type, scenario_names, actual_size):
        """Plot a 2x2 summary of stress-test hypervolumes and save it.

        Panels: (1) mean hypervolume heatmap, (2) std-dev heatmap (lower is
        better), (3) relative FTRL advantage per scenario, (4) worst-case
        (minimum across seeds) hypervolume.

        Parameters
        ----------
        results : dict
            ``results[scenario][config_key]`` is a dict whose ``'hypervolume'``
            entry holds one value per seed. Config keys are classified by
            substring: containing ``'UCB'`` -> UCB (otherwise Thompson);
            containing ``'with_FTRL'`` -> With FTRL (otherwise Without FTRL).
        problem_type : str
            Used in the figure title and the output filename.
        scenario_names : list of str
            Scenarios to include. Their order fixes the row / x-axis order of
            every panel; scenarios absent from ``results`` show up as gaps.
        actual_size : int
            Problem size shown in the figure title.

        The figure is written to ``self.output_dir`` as timestamped PNG
        (300 dpi) and PDF files. If there is no hypervolume data at all,
        nothing is plotted and a message is logged instead.
        """
        # Deduplicate while preserving order: pivot-style tables reject
        # duplicate index labels, and duplicates would double-count seeds.
        scenario_order = list(dict.fromkeys(scenario_names))

        all_data = []
        for scenario_name in scenario_order:
            for config_key, data in results.get(scenario_name, {}).items():
                algorithm = 'UCB' if 'UCB' in config_key else 'Thompson'
                ftrl = 'With FTRL' if 'with_FTRL' in config_key else 'Without FTRL'

                for hv in data.get('hypervolume', []):
                    all_data.append({
                        'Scenario': scenario_name,
                        'Algorithm': algorithm,
                        'FTRL': ftrl,
                        'Config': f"{algorithm} {ftrl}",
                        'Hypervolume': hv
                    })

        df = pd.DataFrame(
            all_data,
            columns=['Scenario', 'Algorithm', 'FTRL', 'Config', 'Hypervolume'],
        )
        if df.empty:
            self.log(f"No stress-test hypervolume data for {problem_type}; "
                     f"skipping stress test plot.")
            return

        # groupby/unstack sorts rows alphabetically; reindex so every panel
        # follows scenario_order (scenarios without data become NaN rows).
        grouped = df.groupby(['Scenario', 'Config'])['Hypervolume']
        pivot_mean = grouped.mean().unstack().reindex(scenario_order)
        pivot_std = grouped.std().unstack().reindex(scenario_order)
        pivot_worst = grouped.min().unstack().reindex(scenario_order)

        # Relative FTRL advantage (%) per algorithm, built directly in
        # scenario_order so bar heights line up with the x tick labels
        # (DataFrame.pivot would re-sort the scenarios alphabetically).
        advantage = {'UCB': [], 'Thompson': []}
        for scenario in scenario_order:
            scenario_df = df[df['Scenario'] == scenario]
            for alg, values in advantage.items():
                alg_df = scenario_df[scenario_df['Algorithm'] == alg]
                with_ftrl = alg_df.loc[alg_df['FTRL'] == 'With FTRL', 'Hypervolume'].mean()
                without_ftrl = alg_df.loc[alg_df['FTRL'] == 'Without FTRL', 'Hypervolume'].mean()

                # A missing (NaN) or non-positive baseline yields 0,
                # because NaN > 0 evaluates to False.
                if without_ftrl > 0:
                    values.append((with_ftrl - without_ftrl) / without_ftrl * 100)
                else:
                    values.append(0)
        df_advantage = pd.DataFrame(advantage, index=scenario_order)

        fig, axes = plt.subplots(2, 2, figsize=(14, 12))
        try:
            # Plot 1: Heatmap of mean performance
            ax1 = axes[0, 0]
            sns.heatmap(pivot_mean, annot=True, fmt='.3f', cmap='RdYlGn', ax=ax1)
            ax1.set_title('Mean Hypervolume by Scenario', fontweight='bold')
            ax1.set_ylabel('Stress Scenario')

            # Plot 2: Heatmap of std (lower is better)
            ax2 = axes[0, 1]
            sns.heatmap(pivot_std, annot=True, fmt='.3f', cmap='RdYlGn_r', ax=ax2)
            ax2.set_title('Std Dev by Scenario (lower=better)', fontweight='bold')
            ax2.set_ylabel('Stress Scenario')

            # Plot 3: FTRL advantage (with - without) by scenario
            ax3 = axes[1, 0]
            x = np.arange(len(scenario_order))
            width = 0.35
            ax3.bar(x - width/2, df_advantage['UCB'], width, label='UCB', color='#0173B2')
            ax3.bar(x + width/2, df_advantage['Thompson'], width, label='Thompson', color='#DE8F05')
            ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
            ax3.set_xticks(x)
            ax3.set_xticklabels(scenario_order, rotation=45, ha='right')
            ax3.set_ylabel('FTRL Advantage (%)')
            ax3.set_title('FTRL Performance Advantage by Scenario', fontweight='bold')
            ax3.legend()

            # Plot 4: Worst-case (min across seeds) performance comparison
            ax4 = axes[1, 1]
            pivot_worst.plot(kind='bar', ax=ax4, width=0.8)
            ax4.set_ylabel('Minimum Hypervolume')
            ax4.set_title('Worst-Case Performance by Scenario', fontweight='bold')
            ax4.legend(title='Configuration', bbox_to_anchor=(1.02, 1), loc='upper left')
            ax4.tick_params(axis='x', rotation=45)

            fig.suptitle(f'Stress Test Results: {problem_type} (n={actual_size})',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()

            timestamp = time.strftime("%Y%m%d-%H%M%S")
            filename = os.path.join(self.output_dir, f'stress_tests_{problem_type}_{timestamp}.png')
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            fig.savefig(filename.replace('.png', '.pdf'), bbox_inches='tight')
        finally:
            # Always release the figure, even if a plotting call raised.
            plt.close(fig)

        self.log(f"Stress test plot saved to: {filename}")
    
    def _plot_regret_analysis(self, results, problem_type, actual_size):
        """Plot a 2x2 regret / convergence summary and save it.

        Panels: (1) mean +/- std cumulative-regret curves, (2) mean +/- std
        best-reward-so-far curves, (3) final hypervolume box plot, and
        (4) a stability score ``mean(HV) / std(HV)`` capped at 100.

        Parameters
        ----------
        results : dict
            ``results[config_name]`` is a dict with optional keys
            ``'regret_curves'`` and ``'best_rewards'`` (one sequence per seed)
            and ``'hypervolume'`` (one value per seed). Config names are
            classified by substring: ``'UCB'`` -> UCB (otherwise Thompson);
            ``'without_FTRL'`` -> Without FTRL (otherwise With FTRL).
        problem_type : str
            Used in the figure title and the output filename.
        actual_size : int
            Problem size shown in the figure title.

        Panels without data show a placeholder note. The figure is written to
        ``self.output_dir`` as timestamped PNG (300 dpi) and PDF files.
        """
        colors = {
            'UCB_with_FTRL': '#0173B2',
            'UCB_without_FTRL': '#DE8F05',
            'Thompson_with_FTRL': '#029E73',
            'Thompson_without_FTRL': '#CC78BC'
        }
        ftrl_palette = {'With FTRL': '#0173B2', 'Without FTRL': '#DE8F05'}
        no_iteration_msg = ('No per-iteration data available\n'
                            '(MOCOEvaluator does not expose iteration details)')

        # Collect per-seed hypervolumes and a per-config stability score.
        hv_rows = []
        stability_rows = []
        for config_name, data in results.items():
            hvs = data.get('hypervolume')
            # Non-positive and NaN hypervolumes are excluded (presumably failed
            # runs). An explicit None check avoids NumPy's ambiguous truth value.
            valid_hvs = [] if hvs is None else [h for h in hvs if h > 0]
            if not valid_hvs:
                continue
            labels = {
                'Config': config_name.replace('_', '\n'),
                'Algorithm': 'UCB' if 'UCB' in config_name else 'Thompson',
                'FTRL': 'Without FTRL' if 'without_FTRL' in config_name else 'With FTRL',
            }
            hv_rows.extend({**labels, 'Hypervolume': h} for h in valid_hvs)
            # Inverse coefficient of variation: higher = more consistent across seeds.
            stability = np.mean(valid_hvs) / (np.std(valid_hvs) + 1e-10)
            stability_rows.append({**labels, 'Stability': min(stability, 100)})  # cap for visualization

        fig, axes = plt.subplots(2, 2, figsize=(14, 12))
        try:
            # Plots 1-2: per-iteration curves (only if the evaluator recorded them).
            curve_panels = (
                (axes[0, 0], 'regret_curves', 'Cumulative Regret', 'Cumulative Regret Curves'),
                (axes[0, 1], 'best_rewards', 'Best Reward So Far', 'Convergence Trajectories'),
            )
            for ax, curve_key, ylabel, title in curve_panels:
                has_data = self._draw_seed_band_curves(ax, results, curve_key, colors)
                if not has_data:
                    self._draw_no_data_note(ax, no_iteration_msg)
                ax.set_xlabel('Iteration')
                ax.set_ylabel(ylabel)
                ax.set_title(title, fontweight='bold')
                if has_data:
                    ax.legend()
                ax.grid(True, alpha=0.3)

            # Plot 3: final hypervolume distribution (per-iteration regret is
            # often unavailable, so final HV is the primary comparison).
            ax3 = axes[1, 0]
            if hv_rows:
                sns.boxplot(data=pd.DataFrame(hv_rows), x='Algorithm', y='Hypervolume',
                            hue='FTRL', ax=ax3, palette=ftrl_palette)
                ax3.set_title('Final Hypervolume Distribution', fontweight='bold')
                ax3.legend(title='FTRL Status')
            else:
                self._draw_no_data_note(ax3, 'No hypervolume data available')
            ax3.grid(True, alpha=0.3, axis='y')

            # Plot 4: stability (mean/std of hypervolume) as a proxy for smoothness.
            ax4 = axes[1, 1]
            if stability_rows:
                sns.barplot(data=pd.DataFrame(stability_rows), x='Algorithm', y='Stability',
                            hue='FTRL', ax=ax4, palette=ftrl_palette)
                ax4.set_title('Result Stability (Mean/Std, higher=better)', fontweight='bold')
                ax4.legend(title='FTRL Status')
            else:
                self._draw_no_data_note(ax4, 'No stability data available')
            ax4.grid(True, alpha=0.3, axis='y')

            fig.suptitle(f'Regret & Convergence Analysis: {problem_type} (n={actual_size})',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()

            timestamp = time.strftime("%Y%m%d-%H%M%S")
            filename = os.path.join(self.output_dir, f'regret_analysis_{problem_type}_{timestamp}.png')
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            fig.savefig(filename.replace('.png', '.pdf'), bbox_inches='tight')
        finally:
            # Always release the figure, even if a plotting call raised.
            plt.close(fig)

        self.log(f"Regret analysis plot saved to: {filename}")

    @staticmethod
    def _draw_seed_band_curves(ax, results, curve_key, colors):
        """Plot mean +/- std of per-seed curves for every config onto ``ax``.

        ``None`` and empty curves are skipped. Shorter curves are right-padded
        with their last value so all seeds can be averaged point-wise. The
        line colour comes from ``colors[config_name]``, falling back to gray.

        Returns
        -------
        bool
            True if at least one config had a non-empty curve and was drawn.
        """
        drew_any = False
        for config_name, data in results.items():
            raw_curves = data.get(curve_key)
            if raw_curves is None:
                continue
            # Explicit None/len checks: `if curve:` raises for NumPy arrays
            # with more than one element and drops length-1 arrays holding 0.
            curves = [c for c in raw_curves if c is not None and len(c) > 0]
            if not curves:
                continue
            drew_any = True

            max_len = max(len(c) for c in curves)
            aligned = np.array(
                [list(c) + [c[-1]] * (max_len - len(c)) for c in curves]
            )
            mean_curve = aligned.mean(axis=0)
            std_curve = aligned.std(axis=0)

            x = np.arange(max_len)
            color = colors.get(config_name, 'gray')
            ax.plot(x, mean_curve, label=config_name.replace('_', ' '),
                    color=color, linewidth=2)
            ax.fill_between(x, mean_curve - std_curve, mean_curve + std_curve,
                            alpha=0.2, color=color)
        return drew_any

    @staticmethod
    def _draw_no_data_note(ax, message):
        """Write a centred placeholder ``message`` on an otherwise empty ``ax``."""
        ax.text(0.5, 0.5, message,
                ha='center', va='center', transform=ax.transAxes, fontsize=10,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))


# ============================================================================
# MAIN RUNNER
# ============================================================================

def run_comprehensive_ftrl_study(
    problem_type: str = 'BiKP',
    problem_size: str = 'medium',
    num_seeds: int = 30,
    output_dir: str = 'ftrl_comprehensive_results',
    studies: Optional[List[str]] = None,
    ftrl_rates: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """
    Run comprehensive FTRL study.
    
    Parameters:
    -----------
    problem_type : str
        Problem type ('BiKP', 'BiTSP')
    problem_size : str
        Problem size ('small', 'medium', 'large')
    num_seeds : int
        Number of random seeds per configuration
    output_dir : str
        Output directory for results
    studies : List[str], optional
        Which studies to run. Options: 'rate', 'variance', 'stress', 'regret'.
        If None, runs all studies. A single name may be passed as a plain
        string. Unknown names trigger a warning and are ignored. Studies
        always run in the fixed order listed above, regardless of the order
        given here.
    ftrl_rates : List[float], optional
        FTRL usage rates evaluated by the 'rate' study.
        Defaults to [0.0, 0.3, 0.5, 0.7, 1.0].

    Returns:
    --------
    Dict[str, Any]
        Maps each study name that ran to the value returned by the
        corresponding ``FTRLComprehensiveStudy.run_*`` method.
    """
    known_studies = ('rate', 'variance', 'stress', 'regret')

    if studies is None:
        studies = list(known_studies)
    elif isinstance(studies, str):
        # Treat a bare string as one study name rather than matching
        # study names against its substrings.
        studies = [studies]

    unknown = [name for name in studies if name not in known_studies]
    if unknown:
        warnings.warn(
            f"Ignoring unknown study name(s) {unknown}; "
            f"valid options are {list(known_studies)}."
        )

    if ftrl_rates is None:
        ftrl_rates = [0.0, 0.3, 0.5, 0.7, 1.0]
    
    print("\n" + "="*80)
    print("COMPREHENSIVE FTRL ABLATION STUDY")
    print("="*80)
    print(f"Problem: {problem_type} ({problem_size})")
    print(f"Seeds per configuration: {num_seeds}")
    print(f"Studies to run: {studies}")
    print(f"Output directory: {output_dir}")
    
    overall_start = time.time()
    
    study = FTRLComprehensiveStudy(
        output_dir=output_dir,
        num_seeds=num_seeds,
        verbose=True
    )

    # (banner message, runner) per study; iterated in known_studies order.
    runners = {
        'rate': ("Running FTRL Rate Study...",
                 lambda: study.run_ftrl_rate_study(
                     problem_type=problem_type,
                     problem_size=problem_size,
                     ftrl_rates=ftrl_rates)),
        'variance': ("Running Variance Analysis...",
                     lambda: study.run_variance_analysis(
                         problem_type=problem_type,
                         problem_size=problem_size)),
        'stress': ("Running Stress Tests...",
                   lambda: study.run_stress_tests(
                       problem_type=problem_type,
                       problem_size=problem_size)),
        'regret': ("Running Regret Analysis...",
                   lambda: study.run_regret_analysis(
                       problem_type=problem_type,
                       problem_size=problem_size)),
    }

    results = {}
    for name in known_studies:
        if name not in studies:
            continue
        banner, run = runners[name]
        print("\n" + "="*80)
        print(banner)
        results[name] = run()
    
    overall_time = time.time() - overall_start
    
    print("\n" + "="*80)
    print("COMPREHENSIVE STUDY COMPLETED")
    print("="*80)
    print(f"Total time: {overall_time/60:.2f} minutes")
    print(f"Results saved to: {output_dir}/")
    
    return results




# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Only Example 2 is active. To run a different configuration, uncomment
    # one of the other examples and comment out Example 2.

    # Example 1: Full study, all four analyses on BiKP medium with 30 seeds
    # results = run_comprehensive_ftrl_study(
    #     problem_type='BiKP',
    #     problem_size='medium',
    #     num_seeds=30,
    #     output_dir='ftrl_comprehensive_results',
    #     studies=['rate', 'variance', 'stress', 'regret']
    # )
    
    # Example 2 (active): Quick smoke test on BiTSP medium with only 2 seeds
    results = run_comprehensive_ftrl_study(
        problem_type='BiTSP',
        problem_size='medium',
        num_seeds=2,  # Reduced for quick testing
        output_dir='ftrl_comprehensive_results_2',
        studies=['rate', 'variance', 'stress', 'regret']  # All four; pass a subset to run fewer
    )
    
    # Example 3: Full study on BiTSP medium with 30 seeds
    # results = run_comprehensive_ftrl_study(
    #     problem_type='BiTSP',
    #     problem_size='medium',
    #     num_seeds=30,
    #     output_dir='ftrl_tsp_results',
    #     studies=['rate', 'variance', 'stress', 'regret']
    # )