"""
    Enhanced ESCORT Framework Evaluation on Multi-modal Correlated Distribution
    
    This script evaluates the improved ESCORT framework against other methods on a 
    challenging multi-modal distribution with complex correlation structures.
    
    Uses multiple random seeds/initializations and reports statistics as mean ± standard error.
    
    Key improvements:
    1. Better mode detection and preservation
    2. Improved correlation-aware updates
    3. Enhanced initialization strategy
    4. Mode-balancing resampling 
    5. Adaptive parameter tuning
    6. Statistical evaluation with multiple runs
"""
import traceback
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import multivariate_normal, sem
import pandas as pd
import time
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from sklearn.cluster import KMeans, DBSCAN, MeanShift

# Import required libraries - adjust paths as needed
from belief_assessment.distributions import GMMDistribution
from belief_assessment.evaluation.visualize_distributions import GMMVisualizer
from escort.utils.kernels import RBFKernel
from escort.gswd import GSWD
from escort.svgd import SVGD, AdaptiveSVGD
from dvrl.dvrl import DVRL

# Absolute directory that holds this script; used as the base for saved outputs.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]

# Import evaluation helper functions. If the shared evaluation module cannot be
# imported, define minimal local stand-ins for the two adapters only; the metric
# helpers (compute_mmd, compute_ess, ...) are NOT redefined here.
try:
    from tests.evaluate_2d_1 import DVRLAdapter, SIRAdapter
    from tests.evaluate_2d_1 import compute_mmd, compute_ess, compute_mode_coverage_2d
    from tests.evaluate_2d_1 import compute_correlation_error, compute_sliced_wasserstein_distance
    from tests.evaluate_2d_1 import estimate_kl_divergence_2d, evaluate_method
except ImportError:

    class DVRLAdapter:
        """Fallback adapter exposing a DVRL model through ``fit_transform``."""

        def __init__(self, dvrl_model, n_samples=1000):
            self.dvrl_model = dvrl_model
            self.n_samples = n_samples

        def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
            """
            Return particles drawn via ``dvrl_model.init_particles(1, n_samples)``.

            ``score_fn`` and ``target_samples`` are not used. ``initial_particles``
            is returned unchanged if the DVRL call fails.

            Returns:
                np.ndarray, or ``(np.ndarray, dict)`` when ``return_convergence``
                is true. The tuple shape is honoured on the error path as well.
            """
            try:
                _, z, _ = self.dvrl_model.init_particles(1, self.n_samples)
                # assumes z is a torch tensor shaped (1, n_samples, dim) — TODO confirm
                particles = z.squeeze(0).detach().cpu().numpy()
            except Exception as e:
                print(f"Error in DVRL: {e}")
                particles = initial_particles

            if return_convergence:
                return particles, {}
            return particles

    class SIRAdapter:
        """Fallback sampling-importance-resampling adapter."""

        def __init__(self, n_iter=1):
            self.n_iter = n_iter

        @staticmethod
        def _log_probs(score_fn, particles):
            """
            Obtain one log-probability per particle from ``score_fn``.

            First tries ``score_fn(particles, return_logp=True)``, expected to
            return ``(score, log_probs)``. If that call signature is unsupported,
            falls back to treating ``score_fn(particles)`` as the log-probabilities.

            Raises:
                ValueError: if the result is not a 1-D array with one entry per
                    particle (e.g. a per-dimension gradient was returned).
            """
            try:
                _, log_probs = score_fn(particles, return_logp=True)
            except (TypeError, ValueError):
                # TypeError: unknown keyword; ValueError: result was not a pair.
                log_probs = score_fn(particles)

            log_probs = np.asarray(log_probs, dtype=float)
            if log_probs.ndim != 1 or len(log_probs) != len(particles):
                raise ValueError(
                    "score_fn must provide one log-probability per particle; "
                    f"got array of shape {log_probs.shape}"
                )
            return log_probs

        def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
            """
            Run ``n_iter`` rounds of importance weighting with ESS-triggered resampling.

            When the normalized ESS drops below 0.5, particles are multinomially
            resampled and jittered with N(0, 0.1^2) noise. On any error the
            initial particles are returned (in a tuple when ``return_convergence``).
            """
            try:
                particles = initial_particles.copy()

                for iteration in range(self.n_iter):
                    log_probs = self._log_probs(score_fn, particles)

                    # Shift by the max before exponentiating for numerical stability.
                    probs = np.exp(log_probs - np.max(log_probs))
                    weights = probs / np.sum(probs)

                    ess_ratio = (1.0 / np.sum(weights ** 2)) / len(particles)

                    if ess_ratio < 0.5:
                        print(f"Iteration {iteration}: Resampling (ESS = {ess_ratio:.4f})")
                        indices = np.random.choice(
                            len(particles), len(particles), p=weights, replace=True
                        )
                        particles = particles[indices]
                        # Small jitter so resampled duplicates do not coincide.
                        particles += np.random.randn(*particles.shape) * 0.1

                if return_convergence:
                    return particles, {"iterations": self.n_iter}
                return particles
            except Exception as e:
                print(f"Error in SIR: {e}")
                if return_convergence:
                    return initial_particles, {}
                return initial_particles
            

class EnhancedSVGD:
    """
    # & Enhanced SVGD with mode-aware initialization, periodic mode rebalancing
    # & and correlation-aware score preconditioning.
    # &
    # & When ``target_info`` supplies mode ``centers`` (and optionally ``covs``),
    # & particles are seeded evenly around those centers and periodically
    # & reassigned to their nearest center. Particles are then moved from
    # & over-represented to under-represented modes.
    """

    def __init__(self, kernel=None, step_size=0.01, n_iter=300, tol=1e-5,
                 mode_aware=True, target_info=None, verbose=True,
                 repulsion_factor=1.5, bandwidth_scale=0.5,
                 use_annealing=True, resample_freq=20):
        """
        # & Initialize enhanced SVGD
        # &
        # & Args:
        # &     kernel: Kernel exposing ``evaluate(x)`` and ``gradient(x)``
        # &         (default: adaptive RBF kernel, see ``_make_default_kernel``)
        # &     step_size (float): Initial step size
        # &     n_iter (int): Maximum number of iterations
        # &     tol (float): Convergence tolerance on the per-particle delta norm
        # &     mode_aware (bool): Whether to use mode-aware updates
        # &     target_info (dict): Optional ``{'centers': ..., 'covs': ...}``
        # &     verbose (bool): Whether to display progress
        # &     repulsion_factor (float): Factor to scale repulsive forces
        # &     bandwidth_scale (float): Scale factor for the fallback kernel's
        # &         median-heuristic bandwidth
        # &     use_annealing (bool): Whether to add decaying exploration noise
        # &     resample_freq (int): Iterations between mode rebalancing passes
        """
        if kernel is None:
            kernel = self._make_default_kernel(bandwidth_scale)

        self.kernel = kernel
        self.step_size = step_size
        self.n_iter = n_iter
        self.tol = tol
        self.mode_aware = mode_aware
        self.target_info = target_info
        self.verbose = verbose
        self.repulsion_factor = repulsion_factor
        self.bandwidth_scale = bandwidth_scale
        self.use_annealing = use_annealing
        self.resample_freq = resample_freq

        # Mode bookkeeping, populated by _initialize_particles / updates.
        self.mode_centers = None
        self.mode_covs = None
        self.mode_assignments = None
        self.iterations_run = 0

    # ------------------------------------------------------------------ kernel

    @staticmethod
    def _make_default_kernel(bandwidth_scale):
        """
        # & Build the default adaptive RBF kernel.
        # &
        # & Prefers ``escort.utils.kernels.RBFKernel``. If that import fails, it
        # & falls back to a small NumPy kernel with the same evaluate/gradient API.
        """
        try:
            from escort.utils.kernels import RBFKernel
            return RBFKernel(bandwidth=None, adaptive=True)
        except ImportError:
            pass

        class SimpleRBFKernel:
            """RBF kernel k(x, y) = exp(-||x - y||^2 / h) with median-heuristic h."""

            def __init__(self, bandwidth=None, adaptive=True):
                self.bandwidth = bandwidth
                self.adaptive = adaptive

            @staticmethod
            def _pairwise_diff(x, y):
                # diff[i, j] = x[i] - y[j], shape (len(x), len(y), dim)
                return x[:, np.newaxis, :] - y[np.newaxis, :, :]

            def evaluate(self, x, y=None):
                if y is None:
                    y = x
                x = np.atleast_2d(x)
                y = np.atleast_2d(y)

                squared_dists = np.sum(self._pairwise_diff(x, y) ** 2, axis=-1)

                # Re-estimate h on every call while adaptive (or unset).
                if self.bandwidth is None or self.adaptive:
                    if squared_dists.size > 0:
                        self.bandwidth = np.median(squared_dists) * bandwidth_scale
                        if self.bandwidth < 1e-5:
                            self.bandwidth = 1.0
                    else:
                        self.bandwidth = 1.0

                return np.exp(-squared_dists / self.bandwidth)

            def gradient(self, x, y=None):
                if y is None:
                    y = x
                x = np.atleast_2d(x)
                y = np.atleast_2d(y)

                K = self.evaluate(x, y)
                # grad[i, j] = K[i, j] * (x[i] - y[j]) / h.
                # NOTE(review): this is -1/2 * d k(x_i, y_j) / d x_i. The sign and
                # scale should match escort's RBFKernel.gradient — confirm.
                return K[:, :, np.newaxis] * self._pairwise_diff(x, y) / self.bandwidth

        return SimpleRBFKernel(bandwidth=None, adaptive=True)

    # ------------------------------------------------------------ mode helpers

    def _mode_cov(self, mode_idx):
        """
        # & Return the covariance for ``mode_idx``, or None if none is known.
        """
        if self.mode_covs is not None and mode_idx < len(self.mode_covs):
            return self.mode_covs[mode_idx]
        return None

    @staticmethod
    def _sample_around_mode(center, cov, n_samples, scale):
        """
        # & Draw ``n_samples`` points around ``center``.
        # &
        # & With a usable ``cov``, returns
        # & ``center + scale * z @ chol(cov + 1e-5 I).T`` with z ~ N(0, I).
        # & Without a cov, or if Cholesky/shape checks fail, returns
        # & ``center + 0.5 * z``.
        """
        center = np.asarray(center, dtype=float)
        dim = center.shape[-1]
        if cov is not None:
            try:
                # Small ridge keeps Cholesky stable for near-singular covariances.
                chol = np.linalg.cholesky(np.asarray(cov, dtype=float) + 1e-5 * np.eye(dim))
                return center + (np.random.randn(n_samples, dim) @ chol.T) * scale
            except (np.linalg.LinAlgError, ValueError):
                pass
        return center + np.random.randn(n_samples, dim) * 0.5

    def _initialize_particles(self, particles):
        """
        # & Initialize particles with distribution-aware strategy
        # &
        # & With target centers, every particle is re-drawn around a center
        # & (even split, remainder to the first modes). Otherwise the supplied
        # & particles are scaled by 1.2.
        # &
        # & Args:
        # &     particles (np.ndarray): Initial particles, shape (n, dim)
        # &
        # & Returns:
        # &     np.ndarray: Initialized particles
        """
        n_particles, _ = particles.shape

        centers = self.target_info.get('centers') if self.target_info is not None else None
        if centers is None or len(centers) == 0:
            # No usable mode information: widen the supplied cloud slightly.
            return particles * 1.2

        self.mode_centers = centers
        self.mode_covs = self.target_info.get('covs', None)
        n_modes = len(centers)

        base, extra = divmod(n_particles, n_modes)
        counts = [base + (1 if i < extra else 0) for i in range(n_modes)]

        new_particles = np.zeros_like(particles)
        self.mode_assignments = np.zeros(n_particles, dtype=int)

        start = 0
        for mode_idx, n_mode in enumerate(counts):
            stop = start + n_mode
            # Scale 0.7 gives tighter clusters than the raw covariance.
            new_particles[start:stop] = self._sample_around_mode(
                centers[mode_idx], self._mode_cov(mode_idx), n_mode, scale=0.7)
            self.mode_assignments[start:stop] = mode_idx
            start = stop

        return new_particles

    def _update_mode_assignments(self, particles):
        """
        # & Assign each particle to its nearest mode center (squared Euclidean).
        # &
        # & Args:
        # &     particles (np.ndarray): Current particles
        """
        if not self.mode_aware or self.mode_centers is None:
            return

        centers = np.asarray(self.mode_centers, dtype=float)
        sq_dists = np.sum((particles[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=-1)
        self.mode_assignments = np.argmin(sq_dists, axis=1)

    def _balance_modes(self, particles):
        """
        # & Balance particles across modes to prevent mode collapse
        # &
        # & A mode holding fewer than 70% of its equal share receives particles
        # & from the first mode holding more than 120%. At least 80% of the
        # & donor's share is kept. Moved particles are re-drawn around the
        # & receiving center. ``particles`` and ``self.mode_assignments`` are
        # & modified in place.
        # &
        # & Args:
        # &     particles (np.ndarray): Current particles
        # &
        # & Returns:
        # &     np.ndarray: Balanced particles
        """
        if not self.mode_aware or self.mode_centers is None or self.mode_assignments is None:
            return particles

        n_particles = len(particles)
        n_modes = len(self.mode_centers)
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        target_count = n_particles / n_modes

        for mode_idx in range(n_modes):
            if mode_counts[mode_idx] >= target_count * 0.7:
                continue

            over_modes = np.where(mode_counts > target_count * 1.2)[0]
            if len(over_modes) == 0:
                continue

            source_mode = over_modes[0]
            source_indices = np.where(self.mode_assignments == source_mode)[0]
            deficit = int(target_count - mode_counts[mode_idx])
            n_transfer = min(deficit, len(source_indices) - int(target_count * 0.8))
            if n_transfer <= 0:
                continue

            transfer = source_indices[:n_transfer]
            particles[transfer] = self._sample_around_mode(
                self.mode_centers[mode_idx], self._mode_cov(mode_idx), n_transfer, scale=0.5)
            self.mode_assignments[transfer] = mode_idx

            mode_counts[source_mode] -= n_transfer
            mode_counts[mode_idx] += n_transfer

        return particles

    def _mode_force_weights(self, n_particles):
        """
        # & Per-mode force multipliers that boost under-represented modes.
        # &
        # & Weight = clip(sqrt(ideal / count), 0.8, 2.0), with
        # & ``ideal = n_particles / n_modes``. ``n_modes`` is the number of
        # & target centers (consistent with ``_balance_modes``), not just the
        # & occupied modes. Empty modes keep weight 1.
        # &
        # & Returns:
        # &     np.ndarray: Weights indexed by mode label.
        """
        labels = np.asarray(self.mode_assignments, dtype=int)
        if self.mode_centers is not None:
            n_modes = len(self.mode_centers)
        else:
            n_modes = len(np.unique(labels))

        counts = np.bincount(labels, minlength=n_modes)
        weights = np.ones(len(counts), dtype=np.float64)
        if n_modes == 0 or counts.sum() == 0:
            return weights

        ideal_count = n_particles / n_modes
        occupied = counts > 0
        weights[occupied] = np.clip(np.sqrt(ideal_count / counts[occupied]), 0.8, 2.0)
        return weights

    # ----------------------------------------------------------- update rule

    @staticmethod
    def _safe_scores(particles, score_fn):
        """
        # & Evaluate ``score_fn`` with non-finite values replaced by 0.
        # &
        # & If ``score_fn`` raises, a RuntimeWarning is issued and zero scores
        # & are used, so the step is driven by repulsion/noise only.
        """
        try:
            scores = score_fn(particles)
        except Exception as exc:
            import warnings
            warnings.warn(f"score_fn failed ({exc!r}); using zero scores",
                          RuntimeWarning, stacklevel=3)
            return np.zeros_like(particles)
        return np.nan_to_num(np.asarray(scores, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)

    def _precondition_scores(self, score_values):
        """
        # & In place, rescale each particle's score along its mode's covariance
        # & eigenvectors by sqrt(eigenvalue), with eigenvalues floored at 1e-6.
        # &
        # & Modes whose covariance cannot be decomposed or is shape-incompatible
        # & are left unchanged.
        """
        for mode_idx, cov in enumerate(self.mode_covs):
            idx = np.flatnonzero(self.mode_assignments == mode_idx)
            if idx.size == 0:
                continue
            try:
                eigvals, eigvecs = np.linalg.eigh(cov)
                scale = np.sqrt(np.maximum(eigvals, 1e-6))
                # Row-vector form of eigvecs @ diag(scale) @ eigvecs.T @ s.
                score_values[idx] = ((score_values[idx] @ eigvecs) * scale) @ eigvecs.T
            except (np.linalg.LinAlgError, ValueError):
                continue

    def _repulsion_scale(self, iteration):
        """
        # & Repulsion multiplier by training phase.
        # &
        # & Phases (fractions of ``n_iter``): 1.5x below 30%, 1.0x below 60%,
        # & otherwise 0.8x.
        """
        if iteration < self.n_iter * 0.3:
            phase = 1.5   # strong repulsion early to discover modes
        elif iteration < self.n_iter * 0.6:
            phase = 1.0
        else:
            phase = 0.8   # weaker repulsion during refinement
        return self.repulsion_factor * phase

    def _compute_update(self, particles, score_fn, iteration):
        """
        # & Compute SVGD update with mode-awareness
        # &
        # & Args:
        # &     particles (np.ndarray): Current particles, shape (n, dim)
        # &     score_fn (callable): Score function
        # &     iteration (int): Current iteration number
        # &
        # & Returns:
        # &     np.ndarray: Update direction, each row clipped to norm <= 3
        """
        n_particles = particles.shape[0]

        # Refresh mode membership every 5 iterations.
        if self.mode_aware and iteration % 5 == 0:
            self._update_mode_assignments(particles)

        score_values = self._safe_scores(particles, score_fn)

        if self.mode_aware and self.mode_covs is not None and self.mode_assignments is not None:
            self._precondition_scores(score_values)

        K = np.nan_to_num(self.kernel.evaluate(particles), nan=0.0, posinf=0.0, neginf=0.0)
        grad_K = np.nan_to_num(self.kernel.gradient(particles), nan=0.0, posinf=0.0, neginf=0.0)

        # attractive[i] = mean_j K[i, j] * score[j]
        attractive = K @ score_values / n_particles
        # repulsive[i] = mean_j grad_K[j, i]
        repulsive = grad_K.mean(axis=0) * self._repulsion_scale(iteration)

        if self.mode_aware and self.mode_assignments is not None:
            per_particle = self._mode_force_weights(n_particles)[self.mode_assignments][:, np.newaxis]
            attractive = attractive * per_particle
            repulsive = repulsive * per_particle

        update = attractive + repulsive

        # Clip per-particle update norms for stability.
        max_norm = 3.0
        norms = np.linalg.norm(update, axis=1)
        too_large = norms > max_norm
        if np.any(too_large):
            update[too_large] *= (max_norm / norms[too_large])[:, np.newaxis]

        return update

    # ------------------------------------------------------------------ driver

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Run enhanced SVGD optimization
        # &
        # & Args:
        # &     initial_particles (np.ndarray): Initial particles, shape (n, dim)
        # &     score_fn (callable): Score function
        # &     target_samples (np.ndarray, optional): Unused, kept for API parity
        # &     return_convergence (bool): Whether to return convergence info
        # &
        # & Returns:
        # &     np.ndarray: Optimized particles
        # &     dict (optional): Convergence information
        """
        # Defensive copy so the caller's array is never modified.
        particles = initial_particles.copy()
        n_particles, _ = particles.shape
        particles = self._initialize_particles(particles)

        current_step = self.step_size
        delta_norm_history = []
        step_size_history = []

        # Decaying exploration noise; disabled entirely without annealing.
        if self.use_annealing:
            current_noise, noise_decay = 0.1, 0.97
        else:
            current_noise, noise_decay = 0.0, 1.0

        # Pre-set so the final report is defined even when n_iter == 0.
        delta_norm = float('nan')

        iterator = range(self.n_iter)
        if self.verbose:
            try:
                from tqdm import tqdm
                iterator = tqdm(iterator, desc="Enhanced SVGD")
            except ImportError:
                pass

        for t in iterator:
            update = self._compute_update(particles, score_fn, t)

            if current_noise > 0:
                update = update + np.random.randn(*particles.shape) * current_noise
                current_noise *= noise_decay

            new_particles = particles + current_step * update

            if self.mode_aware and t > 0 and t % self.resample_freq == 0:
                new_particles = self._balance_modes(new_particles)

            delta_norm = np.linalg.norm(new_particles - particles) / n_particles
            delta_norm_history.append(delta_norm)
            step_size_history.append(current_step)

            particles = new_particles
            current_step = self.step_size / (1.0 + 0.02 * t)

            # Only allow early stopping in the second half of the run.
            if t > self.n_iter * 0.5 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")

        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run,
                # Copy so callers cannot mutate internal state.
                'mode_assignments': (None if self.mode_assignments is None
                                     else np.array(self.mode_assignments, copy=True)),
            }
            return particles, convergence_info

        return particles


class EnhancedSVGDAdapter:
    """
    # & Wrapper that exposes a pre-tuned, mode-aware EnhancedSVGD through the
    # & shared ``fit_transform`` evaluation interface.
    """

    def __init__(self, n_iter=300, step_size=0.01, verbose=True, target_info=None):
        """
        # & Construct the wrapped EnhancedSVGD optimizer
        # &
        # & Args:
        # &     n_iter (int): Number of iterations
        # &     step_size (float): Step size for updates
        # &     verbose (bool): Whether to display progress
        # &     target_info (dict): Information about target distribution
        """
        # Settings this adapter fixes regardless of caller input.
        tuned_settings = {
            'mode_aware': True,
            'repulsion_factor': 2.0,   # stronger repulsion against mode collapse
            'bandwidth_scale': 0.5,    # conservative kernel bandwidth
            'use_annealing': True,     # annealed noise for exploration
            'resample_freq': 15,       # rebalance modes every 15 iterations
        }
        self.svgd = EnhancedSVGD(
            step_size=step_size,
            n_iter=n_iter,
            verbose=verbose,
            target_info=target_info,
            **tuned_settings,
        )

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Delegate optimization to the wrapped EnhancedSVGD
        # &
        # & Args:
        # &     initial_particles (np.ndarray): Initial particles
        # &     score_fn (callable): Score function
        # &     target_samples (np.ndarray, optional): Target distribution samples
        # &     return_convergence (bool): Whether to return convergence info
        # &
        # & Returns:
        # &     np.ndarray: Optimized particles
        # &     dict (optional): Convergence information
        """
        optimizer = self.svgd
        return optimizer.fit_transform(
            initial_particles, score_fn, target_samples, return_convergence
        )


class StableSVGD(SVGD):
    """
    Improved SVGD with numerical stability guarantees

    On top of the base ``SVGD`` this class adds:
      * mode-aware initialization when ``target_info`` supplies mode
        ``'centers'`` (and optionally ``'covs'``),
      * nearest-center mode assignment, refreshed every 10 iterations,
      * covariance-based rescaling of the score for 2-D problems,
      * mode-balanced weighting of the SVGD update,
      * NaN/inf scrubbing and per-particle norm clipping of every update,
      * decaying exploration noise and periodic (mode-balanced) resampling.
    """

    # Maximum L2 norm allowed for one particle's update direction
    _MAX_UPDATE_NORM = 5.0
    # Maximum L2 norm allowed for one particle's exploration noise
    _MAX_NOISE_NORM = 0.5

    def __init__(self, kernel=None, step_size=0.01, n_iter=300, tol=1e-5, 
                 bandwidth_scale=0.5, add_noise=True, noise_level=0.1, 
                 noise_decay=0.98, resample_freq=20, adaptive_step=True, 
                 mode_detection=True, lambda_corr=0.2, verbose=True,
                 target_info=None):
        """
        Initialize stable SVGD

        Args:
            kernel: Kernel function (default: RBFKernel(bandwidth=1.0, adaptive=True))
            step_size (float): Initial step size
            n_iter (int): Maximum number of iterations
            tol (float): Convergence tolerance on the normalized update norm
            bandwidth_scale (float): Scale factor for kernel bandwidth (stored;
                not read by any StableSVGD method)
            add_noise (bool): Whether to add noise during optimization
            noise_level (float): Initial noise standard deviation
            noise_decay (float): Multiplicative decay applied after each noise injection
            resample_freq (int): Resampling period in iterations (<= 0 or None disables)
            adaptive_step (bool): Whether to decay the step size as step_size / (1 + 0.02 t)
            mode_detection (bool): Whether to detect and preserve modes
            lambda_corr (float): Weight for correlation preservation (stored;
                not read by any StableSVGD method)
            verbose (bool): Whether to display progress
            target_info (dict): Information about target distribution; the keys
                'centers' and (optionally) 'covs' are used
        """
        # Create default kernel if not provided
        if kernel is None:
            kernel = RBFKernel(bandwidth=1.0, adaptive=True)

        super().__init__(kernel=kernel, step_size=step_size,
                         n_iter=n_iter, tol=tol, verbose=verbose)

        self.bandwidth_scale = bandwidth_scale
        self.add_noise = add_noise
        self.noise_level = noise_level
        self.noise_decay = noise_decay
        self.resample_freq = resample_freq
        self.adaptive_step = adaptive_step
        self.mode_detection = mode_detection
        self.lambda_corr = lambda_corr

        self.target_info = target_info

        # Mode-related state, filled by initialize_particles / _update_mode_assignments
        self.detected_modes = None
        self.mode_assignments = None
        self._mode_centers = None
        self._mode_covs = None

        # Number of iterations executed by the last update() call
        self.iterations_run = 0

    # ------------------------------------------------------------------
    # Small numeric helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _clip_row_norms(vectors, max_norm):
        """Scale down, in place, every row of ``vectors`` whose L2 norm exceeds ``max_norm``; return ``vectors``."""
        norms = np.linalg.norm(vectors, axis=1)
        too_large = norms > max_norm
        if np.any(too_large):
            vectors[too_large] *= (max_norm / norms[too_large])[:, np.newaxis]
        return vectors

    @staticmethod
    def _sample_around(center, cov, n_samples, dim, cov_scale, iso_scale):
        """
        Draw ``n_samples`` points around ``center``.

        Uses ``center + z @ chol(cov + 1e-5 I).T * cov_scale`` when ``cov`` is
        given and factorizable, otherwise an isotropic Gaussian with std
        ``iso_scale``.
        """
        center = np.asarray(center, dtype=float)
        if cov is not None:
            try:
                chol = np.linalg.cholesky(np.asarray(cov, dtype=float) + 1e-5 * np.eye(dim))
            except (np.linalg.LinAlgError, ValueError):
                chol = None  # not positive definite / wrong shape -> isotropic fallback
            if chol is not None:
                return center + np.random.randn(n_samples, dim) @ chol.T * cov_scale
        return center + np.random.randn(n_samples, dim) * iso_scale

    @staticmethod
    def _safe_scores(particles, score_fn):
        """Evaluate ``score_fn`` as a float copy with NaN/inf replaced by 0; zeros if it raises."""
        try:
            # Copy so later in-place preconditioning never mutates score_fn's own buffer
            scores = np.array(score_fn(particles), dtype=np.float64)
        except Exception:
            return np.zeros_like(particles, dtype=np.float64)
        return np.nan_to_num(scores, nan=0.0, posinf=0.0, neginf=0.0)

    # ------------------------------------------------------------------
    # Mode handling
    # ------------------------------------------------------------------
    def initialize_particles(self, particles):
        """
        Initialize particles with improved strategy

        When ``target_info`` contains a non-empty ``'centers'`` entry, the
        particles are redrawn so they are split as evenly as possible across
        the known modes (the first ``n_particles % n_modes`` modes receive one
        extra particle). Each group is sampled around its center using the
        mode covariance (scaled by 0.6) when available, otherwise an isotropic
        Gaussian with std 0.5. Mode centers/covariances and the initial mode
        assignments are cached on the instance.

        Args:
            particles (np.ndarray): Initial particles, shape (n_particles, dim)

        Returns:
            np.ndarray: Initialized particles (the input is returned unchanged
            when no usable mode centers are available)
        """
        n_particles, dim = particles.shape

        if self.target_info is None or 'centers' not in self.target_info:
            return particles

        centers = self.target_info['centers']
        covs = self.target_info.get('covs', None)
        n_modes = len(centers)
        if n_modes == 0:
            # Nothing to distribute over (and would divide by zero below)
            return particles

        self._mode_centers = centers
        self._mode_covs = covs

        base, remainder = divmod(n_particles, n_modes)
        sizes = [base + (1 if i < remainder else 0) for i in range(n_modes)]

        new_particles = np.zeros((n_particles, dim))
        self.mode_assignments = np.zeros(n_particles, dtype=int)

        start = 0
        for mode_idx, size in enumerate(sizes):
            stop = start + size
            cov = covs[mode_idx] if covs is not None and mode_idx < len(covs) else None
            # 0.6 gives tighter initial clusters than the mode covariance itself
            new_particles[start:stop] = self._sample_around(
                centers[mode_idx], cov, size, dim, cov_scale=0.6, iso_scale=0.5)
            self.mode_assignments[start:stop] = mode_idx
            start = stop

        return new_particles

    def _update_mode_assignments(self, particles):
        """
        Update mode assignments based on mode centers

        Each particle is assigned to its nearest (squared Euclidean) cached
        mode center; ties go to the lowest mode index.

        Args:
            particles (np.ndarray): Current particles, shape (n_particles, dim)

        Returns:
            tuple: (mode_centers, mode_assignments), or (None, None) when no
            mode centers are known
        """
        if self._mode_centers is None or len(self._mode_centers) == 0:
            return None, None

        centers = np.asarray(self._mode_centers, dtype=float)
        sq_dists = np.sum((particles[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=2)
        self.mode_assignments = np.argmin(sq_dists, axis=1)

        return self._mode_centers, self.mode_assignments

    def _mode_balance_weights(self, n_particles):
        """
        Per-particle weights clip(target / count, 0.8, 1.5) of each particle's mode.

        ``target`` is the uniform share n_particles / n_modes, where n_modes is
        the number of *known* mode centers (empty modes still count), falling
        back to max(assignment) + 1 when no centers are cached.
        """
        assignments = np.asarray(self.mode_assignments, dtype=int)
        if self._mode_centers is not None and len(self._mode_centers) > 0:
            n_modes = len(self._mode_centers)
        else:
            n_modes = int(assignments.max()) + 1
        counts = np.bincount(assignments, minlength=n_modes)
        target = n_particles / n_modes
        mode_weights = np.clip(target / np.maximum(counts, 1), 0.8, 1.5)
        return mode_weights[assignments]

    def _precondition_scores(self, score_values):
        """
        Rescale, in place, each mode's score vectors in that mode's covariance eigenbasis.

        For particles assigned to mode m with covariance C_m = V diag(lam) V^T,
        the score s becomes V (0.8 * sqrt(max(lam, 1e-6)) * (V^T s)).
        Modes whose covariance cannot be decomposed or whose shape does not
        match the scores are left untouched.
        """
        assignments = np.asarray(self.mode_assignments)
        for mode_idx, cov in enumerate(self._mode_covs):
            members = np.flatnonzero(assignments == mode_idx)
            if members.size == 0:
                continue
            try:
                eigvals, eigvecs = np.linalg.eigh(np.asarray(cov, dtype=float))
                scale = np.sqrt(np.maximum(eigvals, 1e-6)) * 0.8
                projected = score_values[members] @ eigvecs
                score_values[members] = (projected * scale) @ eigvecs.T
            except (np.linalg.LinAlgError, ValueError):
                continue

    # ------------------------------------------------------------------
    # Resampling
    # ------------------------------------------------------------------
    def _systematic_resample(self, particles, weights):
        """
        Safe systematic resampling with duplication avoidance

        Non-finite or negative weights are treated as zero; if no positive
        weight remains, uniform weights are used. Selected particles receive
        N(0, 0.05^2) jitter so copies do not coincide exactly.

        Args:
            particles (np.ndarray): Current particles, shape (n_particles, dim)
            weights (np.ndarray): Particle weights, shape (n_particles,)

        Returns:
            np.ndarray: Resampled particles (float array of the same shape)
        """
        n_particles = len(particles)
        if n_particles == 0:
            return np.array(particles, dtype=float, copy=True)

        weights = np.nan_to_num(np.asarray(weights, dtype=np.float64),
                                nan=0.0, posinf=0.0, neginf=0.0)
        weights = np.clip(weights, 0.0, None)
        total = weights.sum()
        if total <= 0:
            weights = np.full(n_particles, 1.0 / n_particles)
        else:
            weights = weights / total

        # One uniform offset, n evenly spaced pointers into the CDF
        positions = (np.random.random() + np.arange(n_particles)) / n_particles
        cumulative_sum = np.cumsum(weights)
        # First index whose cumulative weight reaches each pointer
        indices = np.searchsorted(cumulative_sum, positions, side='left')
        indices = np.minimum(indices, min(len(cumulative_sum), n_particles) - 1)

        new_particles = np.asarray(particles, dtype=float)[indices]
        noise_scale = 0.05  # small jitter to avoid exact duplicates while staying stable
        new_particles += np.random.randn(*new_particles.shape) * noise_scale
        return new_particles

    def _mode_balanced_resample(self, particles):
        """
        Simple mode-balancing resampling strategy

        Every known mode holding fewer than 80% of its uniform share
        (n_particles / n_modes) receives up to ``int(share - count)`` fresh
        particles drawn around its center (covariance scaled by 0.5, or
        isotropic std 0.5). They replace members of over-represented modes,
        taken largest-surplus first and never more than a donor's surplus,
        so a donor is not itself pushed below its share.
        ``self.mode_assignments`` is updated in place.

        Args:
            particles (np.ndarray): Current particles, shape (n_particles, dim)

        Returns:
            np.ndarray: Resampled particles
        """
        if self._mode_centers is None or self.mode_assignments is None:
            return particles

        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        if n_modes == 0:
            return particles

        new_particles = particles.copy()
        dim = new_particles.shape[1]

        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        target_count = n_particles / n_modes

        for mode_idx in range(n_modes):
            if mode_counts[mode_idx] >= target_count * 0.8:  # 20% tolerance
                continue

            deficit = int(target_count - mode_counts[mode_idx])
            center = self._mode_centers[mode_idx]
            cov = None
            if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                cov = self._mode_covs[mode_idx]

            surplus = mode_counts - target_count
            for donor in np.argsort(-surplus, kind='stable'):
                available = int(surplus[donor])
                if deficit <= 0 or available <= 0:
                    break  # donors are sorted, so no later one has a surplus either
                donor_indices = np.flatnonzero(self.mode_assignments == donor)
                n_take = min(deficit, available, len(donor_indices))
                if n_take <= 0:
                    continue
                chosen = donor_indices[:n_take]
                new_particles[chosen] = self._sample_around(
                    center, cov, n_take, dim, cov_scale=0.5, iso_scale=0.5)
                self.mode_assignments[chosen] = mode_idx
                mode_counts[donor] -= n_take
                mode_counts[mode_idx] += n_take
                deficit -= n_take

        return new_particles

    def _maybe_resample(self, new_particles, score_fn):
        """
        Periodic resampling step used by ``update``.

        If ``score_fn`` returns one log-weight per particle (shape (n,)), the
        ESS of the normalized weights is computed and, when below 30% of n,
        particles are resampled (mode-balanced when mode information is
        available, systematic otherwise). When ``score_fn`` returns a gradient
        shaped like the particles (as ``_compute_svgd_update`` uses it) or
        raises, no importance weights exist, so only the cautious
        mode-balancing step is applied (a no-op for balanced modes).
        """
        n_particles = len(new_particles)
        use_modes = self.mode_detection and self._mode_centers is not None

        try:
            log_weights = np.asarray(score_fn(new_particles), dtype=np.float64)
        except Exception:
            log_weights = None

        if log_weights is None or log_weights.shape != (n_particles,):
            if use_modes:
                self._update_mode_assignments(new_particles)
                return self._mode_balanced_resample(new_particles)
            return new_particles

        weights = np.exp(log_weights - np.max(log_weights))
        weights = weights / np.sum(weights)
        ess_ratio = (1.0 / np.sum(weights ** 2)) / n_particles

        if not ess_ratio < 0.3:  # also skips NaN ratios
            return new_particles

        if self.verbose:
            print(f"Resampling (ESS = {ess_ratio:.4f})")

        if use_modes:
            # Count modes on the positions actually being resampled
            self._update_mode_assignments(new_particles)
            return self._mode_balanced_resample(new_particles)
        return self._systematic_resample(new_particles, weights)

    # ------------------------------------------------------------------
    # SVGD step
    # ------------------------------------------------------------------
    def _compute_svgd_update(self, particles, score_fn, iteration=0):
        """
        Compute safe SVGD update with stability guarantees

        direction_i = w_{m(i)} * (sum_j K[i, j] s_j + r * sum_j grad_K[j, i]),
        where s are the (possibly covariance-rescaled) scores, r = 1.5 during
        the first 30% of ``n_iter`` and 1.0 afterwards, and w are the
        mode-balance weights (only when mode detection is on). NaN/inf entries
        are zeroed and each row is clipped to L2 norm ``_MAX_UPDATE_NORM``.

        Args:
            particles (np.ndarray): Current particles, shape (n_particles, dim)
            score_fn (callable): Score function returning an array shaped like particles
            iteration (int): Current iteration number

        Returns:
            np.ndarray: Update direction for particles
        """
        n_particles, dim = particles.shape

        if self.mode_detection and iteration % 10 == 0:
            self._update_mode_assignments(particles)

        score_values = self._safe_scores(particles, score_fn)

        assignments_match = (self.mode_assignments is not None
                             and len(self.mode_assignments) == n_particles)

        if dim == 2 and self._mode_covs is not None and assignments_match:
            self._precondition_scores(score_values)

        K = np.nan_to_num(self.kernel.evaluate(particles), nan=0.0, posinf=0.0, neginf=0.0)
        grad_K = np.nan_to_num(self.kernel.gradient(particles), nan=0.0, posinf=0.0, neginf=0.0)

        # attractive[i] = sum_j K[i, j] * score_j ; repulsive[i] = sum_j grad_K[j, i]
        attractive = K @ score_values
        repulsive = grad_K.sum(axis=0)

        # More exploration early on, but not too aggressive
        if iteration < self.n_iter * 0.3:
            repulsive = repulsive * 1.5

        if self.mode_detection and assignments_match and n_particles > 0:
            particle_weights = self._mode_balance_weights(n_particles)[:, np.newaxis]
            attractive = attractive * particle_weights
            repulsive = repulsive * particle_weights

        direction = np.nan_to_num(attractive + repulsive, nan=0.0, posinf=0.0, neginf=0.0)
        return self._clip_row_norms(direction, self._MAX_UPDATE_NORM)

    def update(self, particles, score_fn, target_samples=None, return_convergence=False):
        """
        Run stable SVGD optimization

        Args:
            particles (np.ndarray): Initial particles, shape (n_particles, dim)
            score_fn (callable): Score function; called on an (n_particles, dim)
                array and expected to return an array of the same shape
            target_samples (np.ndarray, optional): Accepted for interface
                compatibility; not used
            return_convergence (bool): Whether to return convergence info

        Returns:
            np.ndarray: Optimized particles
            dict (optional): Convergence information with 'delta_norm_history',
                'step_size_history', 'iterations_run' and 'mode_assignments'
        """
        particles = self.initialize_particles(particles.copy())
        n_particles, dim = particles.shape

        curr_step_size = self.step_size
        # Persistent multiplier from instability back-offs. It must survive the
        # adaptive decay, which re-derives the step from self.step_size and
        # would otherwise silently undo every reduction.
        stability_scale = 1.0
        current_noise = self.noise_level if self.add_noise else 0.0
        delta_norm = float('nan')  # defined even if the loop never runs
        delta_norm_history = []
        step_size_history = []

        iterator = range(self.n_iter)
        if self.verbose:
            iterator = tqdm(iterator, desc="Improved SVGD")

        for t in iterator:
            direction = self._compute_svgd_update(particles, score_fn, t)

            # Occasional, norm-clipped exploration noise that decays over time
            if current_noise > 0 and t % 5 == 0:
                noise = np.random.randn(*particles.shape) * current_noise
                direction = direction + self._clip_row_norms(noise, self._MAX_NOISE_NORM)
                current_noise *= self.noise_decay

            new_particles = particles + curr_step_size * direction

            freq = self.resample_freq
            if freq and freq > 0 and t > 0 and t % freq == 0:
                new_particles = self._maybe_resample(new_particles, score_fn)

            delta_norm = np.linalg.norm(new_particles - particles) / (n_particles * dim)

            if not np.isfinite(delta_norm) or delta_norm > 1e10:
                # Unstable: back off drastically, discard this step and retry
                stability_scale *= 0.1
                curr_step_size *= 0.1
                if self.verbose:
                    print(f"Unstable update detected! Reducing step size to {curr_step_size:.6f}")
                continue

            delta_norm_history.append(delta_norm)
            step_size_history.append(curr_step_size)
            particles = new_particles

            if self.adaptive_step:
                curr_step_size = stability_scale * self.step_size / (1.0 + 0.02 * t)

            # Convergence is only accepted in the second half of the run
            if t > self.n_iter // 2 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")

        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run,
                'mode_assignments': self.mode_assignments,
            }
            return particles, convergence_info

        return particles

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        Interface method to match other algorithms

        Args:
            initial_particles (np.ndarray): Initial particles
            score_fn (callable): Score function
            target_samples (np.ndarray, optional): Target distribution samples
            return_convergence (bool): Whether to return convergence info

        Returns:
            np.ndarray: Optimized particles
            dict (optional): Convergence information
        """
        return self.update(initial_particles, score_fn, target_samples, return_convergence)


class StableSVGDAdapter:
    """Adapter for StableSVGD to match interface with other methods"""

    def __init__(self, n_iter=300, step_size=0.01, verbose=True, target_info=None):
        """
        Build the wrapped StableSVGD with fixed benchmark settings.

        Args:
            n_iter (int): Maximum number of iterations
            step_size (float): Initial step size
            verbose (bool): Whether to display progress
            target_info (dict): Information about target distribution
        """
        options = {
            'step_size': step_size,
            'n_iter': n_iter,
            'verbose': verbose,
            'add_noise': True,
            'noise_level': 0.05,   # much lower noise than the StableSVGD default
            'noise_decay': 0.95,
            'resample_freq': 20,
            'adaptive_step': True,
            'mode_detection': True,
            'lambda_corr': 0.1,    # reduced correlation factor
            'target_info': target_info,
        }
        self.svgd = StableSVGD(**options)

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """Delegate to StableSVGD.fit_transform and return its result unchanged."""
        optimizer = self.svgd
        return optimizer.fit_transform(initial_particles, score_fn,
                                       target_samples, return_convergence)


# ========================================
# Enhanced RBF Kernel Implementation
# ========================================

class EnhancedRBFKernel(RBFKernel):
    """
    # & Enhanced RBF kernel k(x, y) = exp(-||x - y||^2 / h) with an adaptive
    # & median-heuristic bandwidth and optional per-mode bandwidths
    """
    def __init__(self, bandwidth=None, adaptive=True, min_bandwidth=0.1, max_bandwidth=10.0):
        """
        # & Initialize the enhanced kernel
        # &
        # & Args:
        # &     bandwidth (float, optional): Initial bandwidth
        # &     adaptive (bool): Whether to re-estimate the bandwidth on every evaluation
        # &     min_bandwidth (float): Minimum allowed bandwidth
        # &     max_bandwidth (float): Maximum allowed bandwidth
        """
        super().__init__(bandwidth, adaptive)
        self.min_bandwidth = min_bandwidth
        self.max_bandwidth = max_bandwidth
        self.mode_assignments = None
        self.mode_bandwidths = None
        # Copy of the points from the most recent self-evaluation; used by
        # set_mode_information when no particle positions are passed in
        self._last_points = None

    @staticmethod
    def _pairwise_sq_dists(a, b):
        """Return the (len(a), len(b)) matrix of squared Euclidean distances."""
        diff = a[:, np.newaxis, :] - b[np.newaxis, :, :]
        return np.sum(diff ** 2, axis=-1)

    def set_mode_information(self, mode_assignments, detected_modes, particles=None):
        """
        # & Set mode assignments for mode-aware kernel computations
        # &
        # & For every mode label >= 0 with more than 10 members, the mode
        # & bandwidth is the median squared distance between (up to the first
        # & 100) member positions, clipped into [min_bandwidth, max_bandwidth]
        # & so it can never be zero.
        # &
        # & Args:
        # &     mode_assignments (np.ndarray): Mode label per particle (-1 = noise)
        # &     detected_modes (np.ndarray): Detected mode centers (only checked for None)
        # &     particles (np.ndarray, optional): Particle positions aligned with
        # &         mode_assignments. Defaults to the points of the most recent
        # &         self-evaluation; if none are available (or lengths differ)
        # &         no mode bandwidths are set and the global bandwidth is used.
        """
        self.mode_assignments = mode_assignments

        if mode_assignments is None or detected_modes is None:
            return

        self.mode_bandwidths = {}

        if particles is None:
            particles = self._last_points
        if particles is None:
            return

        points = np.atleast_2d(np.asarray(particles, dtype=float))
        assignments = np.asarray(mode_assignments)
        if len(points) != len(assignments):
            return

        for mode in np.unique(assignments):
            if mode < 0:  # skip noise points (-1)
                continue
            members = np.flatnonzero(assignments == mode)
            if len(members) <= 10:
                continue
            sample = points[members[:100]]  # limit for efficiency
            sq = self._pairwise_sq_dists(sample, sample)
            upper = sq[np.triu_indices(len(sample), k=1)]
            self.mode_bandwidths[int(mode)] = float(
                np.clip(np.median(upper), self.min_bandwidth, self.max_bandwidth))

    def _mode_aware_enabled(self):
        """True when both mode assignments and at least one mode bandwidth are available."""
        return (self.mode_assignments is not None and
                self.mode_bandwidths is not None and
                len(self.mode_bandwidths) > 0)

    def _update_bandwidth(self, x, y, symmetric):
        """
        # & Median heuristic: h = clip(median(d^2) / log(max(n_x, 2)), min_bw, max_bw),
        # & using up to 1000 (symmetric) or 100 (cross) reference rows of x.
        # & A single point gets h = 1.0.
        """
        if x.shape[0] <= 1:
            self.bandwidth = 1.0
            return
        reference = x[:1000] if symmetric else x[:100]
        sq = self._pairwise_sq_dists(reference, y)
        if sq.size:
            h = np.median(sq) / np.log(max(x.shape[0], 2))
            self.bandwidth = np.clip(h, self.min_bandwidth, self.max_bandwidth)

    def _pairwise_bandwidths(self, n):
        """
        # & Return (h, scaling) matrices of shape (n, n): pairs whose first-n mode
        # & labels agree and have a mode bandwidth use it (scaling 1.2); all
        # & other pairs use the global bandwidth (scaling 1.0).
        """
        h = np.full((n, n), self.bandwidth, dtype=float)
        scaling = np.ones((n, n))
        assignments = np.asarray(self.mode_assignments)[:n]
        for mode, mode_h in self.mode_bandwidths.items():
            members = np.flatnonzero(assignments == mode)
            block = np.ix_(members, members)
            h[block] = mode_h
            scaling[block] = 1.2
        return h, scaling

    def evaluate(self, x, y=None):
        """
        # & Evaluate the kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / h_ij)
        # &
        # & h_ij is the global bandwidth, except in the self-evaluation case
        # & (y omitted or the same object as x) with mode information set,
        # & where same-mode pairs use their mode bandwidth.
        # &
        # & Args:
        # &     x (np.ndarray): First set of points
        # &     y (np.ndarray, optional): Second set of points
        # &
        # & Returns:
        # &     np.ndarray: Kernel matrix of shape (len(x), len(y))
        """
        # Decide symmetry before np.atleast_2d, which may create new objects
        symmetric = y is None or y is x
        x = np.atleast_2d(x)
        y = x if symmetric else np.atleast_2d(y)

        if self.bandwidth is None or self.adaptive:
            self._update_bandwidth(x, y, symmetric)

        sq_dists = self._pairwise_sq_dists(x, y)

        if symmetric:
            self._last_points = x.copy()
            if self._mode_aware_enabled():
                h, _ = self._pairwise_bandwidths(x.shape[0])
                return np.exp(-sq_dists / h)

        return np.exp(-sq_dists / self.bandwidth)

    def gradient(self, x, y=None):
        """
        # & Compute the kernel-gradient term used for SVGD repulsion
        # &
        # & Returns grad[i, j] = s_ij * K[i, j] * (x_i - y_j) / h_ij, where s_ij
        # & is 1.2 for same-mode pairs in the mode-aware self-evaluation case
        # & and 1.0 otherwise.
        # & NOTE(review): the exact gradient of exp(-||x - y||^2 / h) w.r.t. y
        # & is 2 * K * (x - y) / h; the factor 2 is omitted here. Callers'
        # & step sizes / repulsion factors may be tuned to this scale, so
        # & confirm before changing it.
        # &
        # & Args:
        # &     x (np.ndarray): First set of points
        # &     y (np.ndarray, optional): Second set of points
        # &
        # & Returns:
        # &     np.ndarray: Array of shape (len(x), len(y), dim)
        """
        symmetric = y is None or y is x
        x = np.atleast_2d(x)
        y = x if symmetric else np.atleast_2d(y)

        K = self.evaluate(x, None if symmetric else y)
        diff = x[:, np.newaxis, :] - y[np.newaxis, :, :]

        if symmetric and self._mode_aware_enabled():
            h, scaling = self._pairwise_bandwidths(x.shape[0])
            coeff = scaling * K / h
        else:
            coeff = K / self.bandwidth

        return coeff[:, :, np.newaxis] * diff


# ========================================
# Improved GSWD Implementation
# ========================================

def _random_unit_rows(n_rows, dim):
    """
    # & Draw ``n_rows`` random Gaussian directions in R^dim, each scaled to unit length.
    """
    directions = np.random.randn(n_rows, dim)
    return directions / np.linalg.norm(directions, axis=1, keepdims=True)


def _matched_quantile_indices(n_target, n_source):
    """
    # & Index pairs into two *sorted* samples that sit at the same quantile level.
    # &
    # & For equal sample sizes this is exactly (arange(n), arange(n)); for unequal
    # & sizes the larger sample is read at evenly spaced quantiles so that 1D
    # & optimal-transport pairings stay aligned instead of truncating the tail.
    # &
    # & Returns:
    # &     tuple(np.ndarray, np.ndarray): indices into sorted target / sorted source
    """
    n_pairs = min(n_target, n_source)
    levels = np.linspace(0.0, 1.0, n_pairs)
    target_idx = np.rint(levels * (n_target - 1)).astype(int)
    source_idx = np.rint(levels * (n_source - 1)).astype(int)
    return target_idx, source_idx


def _quantile_targets(target_proj, source_proj):
    """
    # & Map each projected source value to the target value at the same rank quantile.
    # &
    # & Args:
    # &     target_proj (np.ndarray): 1D projections of the target samples
    # &     source_proj (np.ndarray): 1D projections of the source particles
    # &
    # & Returns:
    # &     np.ndarray: One matched target value per source particle
    """
    target_sorted = np.sort(target_proj)
    n_source = len(source_proj)
    if n_source > 1:
        source_ranks = np.argsort(np.argsort(source_proj))
        levels = source_ranks / (n_source - 1)
    else:
        # A lone particle has no rank spread (rank/(n-1) would be 0/0);
        # pair it with the target median instead of producing NaN.
        levels = np.full(n_source, 0.5)
    return np.interp(levels, np.linspace(0, 1, len(target_sorted)), target_sorted)


class ImprovedGSWD(GSWD):
    """
    # & Enhanced GSWD with better correlation structure preservation
    # &
    # & Adds covariance-eigenvector seeding of projection directions, optional
    # & per-mode projection sets, and quantile-matching regularization updates
    # & that can be computed separately for each detected mode.
    """
    def __init__(self, n_projections=20, projection_method='optimized', 
                optimization_steps=10, correlation_aware=True, mode_aware=True):
        """
        # & Initialize improved GSWD
        # &
        # & Args:
        # &     n_projections (int): Number of projection directions
        # &     projection_method (str): Method to generate projections
        # &         ('pca' uses PCA directions; anything else uses random directions)
        # &     optimization_steps (int): Steps for projection optimization
        # &     correlation_aware (bool): Whether to optimize for correlation preservation
        # &     mode_aware (bool): Whether to use mode-specific projections
        """
        # Pass only the parameters that the parent class accepts
        super().__init__(n_projections, projection_method, optimization_steps, correlation_aware)

        self.mode_aware = mode_aware
        self.mode_assignments = None
        # mode label -> projection matrix (or None if generation failed)
        self.mode_projections = {}

        # Step size for the projection-direction updates in _optimize_projections
        self.learning_rate = 0.01

    def set_mode_information(self, mode_assignments, detected_modes):
        """
        # & Set mode information for mode-aware projections
        # &
        # & Any projections from a previous call are discarded so that stale mode
        # & labels cannot leak into the next regularization step.
        # &
        # & Args:
        # &     mode_assignments (np.ndarray): Mode assignments for each particle (-1 = noise)
        # &     detected_modes (np.ndarray): Detected mode centers, one row per mode
        """
        self.mode_assignments = mode_assignments
        self.mode_projections = {}

        if not (self.mode_aware and mode_assignments is not None and detected_modes is not None):
            return

        assignments = np.asarray(mode_assignments)
        unique_modes = np.unique(assignments)
        unique_modes = unique_modes[unique_modes >= 0]  # Skip noise points (-1)
        if len(unique_modes) == 0:
            return

        n_per_mode = max(5, self.n_projections // len(unique_modes))
        for mode in unique_modes:
            # Only well-populated modes get a dedicated projection set
            if np.count_nonzero(assignments == mode) <= 10:
                continue
            try:
                self.mode_projections[mode] = self._generate_projections(
                    n_projections=n_per_mode,
                    dim=detected_modes.shape[1]
                )
            except Exception:
                # Fallback: get_regularizer uses the global projections for this mode
                self.mode_projections[mode] = None

    def _optimize_projections(self, source, target):
        """
        # & Enhanced projection optimization to better preserve correlation
        # &
        # & Starts from PCA or random unit directions. When correlation-aware, it
        # & seeds the leading rows with the top covariance eigenvectors of target
        # & and source, then runs ``optimization_steps`` gradient steps on the
        # & half of the directions with the largest projected 1D Wasserstein-1
        # & distance. Source and target may have different sample counts.
        # &
        # & Args:
        # &     source (np.ndarray): Source samples, shape (n_source, dim)
        # &     target (np.ndarray): Target samples, shape (n_target, dim)
        # &
        # & Returns:
        # &     np.ndarray: Unit-norm projection directions, shape (n_projections, dim)
        """
        dim = target.shape[1]
        n_proj = self.n_projections

        if self.projection_method == 'pca':
            try:
                from sklearn.decomposition import PCA
                pca_dirs = PCA(n_components=min(dim, 10)).fit(target).components_
                n_random = n_proj - pca_dirs.shape[0]
                if n_random > 0:
                    # Pad the PCA directions with random ones
                    projections = np.vstack([pca_dirs, _random_unit_rows(n_random, dim)])
                else:
                    projections = pca_dirs[:n_proj]
            except Exception:
                # Fallback to random if PCA fails (too few samples, sklearn missing, ...)
                projections = _random_unit_rows(n_proj, dim)
        else:
            projections = _random_unit_rows(n_proj, dim)

        if not (self.correlation_aware and self.optimization_steps > 0):
            return projections

        # Jitter the covariances so eigh is well conditioned
        jitter = 1e-5 * np.eye(dim)
        _, target_evecs = np.linalg.eigh(np.cov(target, rowvar=False) + jitter)
        _, source_evecs = np.linalg.eigh(np.cov(source, rowvar=False) + jitter)

        # Seed the first rows with the top eigenvectors (eigh sorts ascending).
        # Guard n_eigen == 0: the slice [:, -0:] would select *every* column.
        n_eigen = min(dim, n_proj // 2)
        if n_eigen > 0:
            eigen_projections = np.vstack([
                target_evecs[:, -n_eigen:].T,
                source_evecs[:, -n_eigen:].T,
            ])
            projections[:eigen_projections.shape[0]] = eigen_projections

        # Quantile-aligned pairings (identity when sample sizes match)
        t_q, s_q = _matched_quantile_indices(len(target), len(source))

        for _ in range(self.optimization_steps):
            target_proj = target @ projections.T
            source_proj = source @ projections.T

            # 1D Wasserstein-1 distance along every direction at once
            sorted_t = np.sort(target_proj, axis=0)[t_q]
            sorted_s = np.sort(source_proj, axis=0)[s_q]
            w_dists = np.mean(np.abs(sorted_t - sorted_s), axis=0)

            # Update the half of the directions with the largest discrepancy;
            # this may include eigenvector-seeded rows.
            update_indices = np.argsort(w_dists)[n_proj // 2:]

            for idx in update_indices:
                t_proj = target_proj[:, idx]
                s_proj = source_proj[:, idx]
                t_match = np.argsort(t_proj)[t_q]
                s_match = np.argsort(s_proj)[s_q]

                # d/dθ |θ·t - θ·s| = sign(θ·t - θ·s) (t - s); ties count as -1
                signs = np.where(t_proj[t_match] > s_proj[s_match], 1.0, -1.0)
                grad = signs @ (target[t_match] - source[s_match])

                # NOTE(review): this step *decreases* the projected discrepancy
                # along already-discriminative directions; max-sliced variants
                # usually ascend. Confirm this is the intended direction.
                stepped = projections[idx] - self.learning_rate * grad
                norm = np.linalg.norm(stepped)
                if norm > 0:
                    projections[idx] = stepped / norm

        return projections

    def get_regularizer(self, target, source, lambda_reg=0.1):
        """
        # & Compute enhanced GSWD regularization term with mode awareness
        # &
        # & For each projection, every particle is pushed toward the target value
        # & at its own rank quantile. In mode-aware mode, this is done
        # & separately for each mode with at least 5 particles, using that mode's
        # & projection set (or the global one). Noise particles (-1) and particles
        # & in sparse modes receive no update. If ``mode_assignments`` does not
        # & match ``source`` in length, the global path is used instead.
        # &
        # & Args:
        # &     target (np.ndarray): Target samples
        # &     source (np.ndarray): Source samples
        # &     lambda_reg (float): Regularization strength
        # &
        # & Returns:
        # &     np.ndarray: Regularization updates for particles (same shape as source)
        """
        reg_updates = np.zeros_like(source)

        assignments = None if self.mode_assignments is None else np.asarray(self.mode_assignments)
        use_mode_aware = (self.mode_aware and
                          assignments is not None and
                          len(self.mode_projections) > 0 and
                          len(assignments) == len(source))

        if use_mode_aware:
            unique_modes = np.unique(assignments)
            unique_modes = unique_modes[unique_modes >= 0]  # Skip noise points (-1)

            for mode in unique_modes:
                mode_indices = np.flatnonzero(assignments == mode)
                if len(mode_indices) < 5:
                    continue
                mode_particles = source[mode_indices]

                mode_proj = self.mode_projections.get(mode)
                if mode_proj is None:
                    # Use global projections if mode-specific ones are not available
                    mode_proj = self.projections
                if len(mode_proj) == 0:
                    continue

                scale = lambda_reg / len(mode_proj)
                for proj in mode_proj:
                    source_proj = mode_particles @ proj
                    shift = _quantile_targets(target @ proj, source_proj) - source_proj
                    # mode_indices are unique, so fancy-index += is a safe scatter
                    reg_updates[mode_indices] += scale * shift[:, np.newaxis] * proj
        else:
            n_dirs = len(self.projections)
            if n_dirs == 0:
                return reg_updates
            scale = lambda_reg / n_dirs
            for proj in self.projections:
                source_proj = source @ proj
                shift = _quantile_targets(target @ proj, source_proj) - source_proj
                reg_updates += scale * shift[:, np.newaxis] * proj

        return reg_updates


# ==========================================
# OptimizedESCORT Implementation - Unchanged
# ==========================================

class OptimizedESCORT(AdaptiveSVGD):
    """
    # & Advanced implementation of ESCORT with superior multi-modal handling
    # & and correlation structure preservation capabilities
    # &
    # & Only the constructor is defined in this class body; fitting and all
    # & other behaviour are inherited from AdaptiveSVGD.
    """
    def __init__(self, kernel=None, gswd=None, step_size=0.02, 
                 n_iter=300, tol=1e-5, lambda_reg=0.3,  # Increased regularization weight
                 decay_step_size=True, verbose=True, 
                 noise_level=0.18, noise_decay=0.96,   # Better noise schedule
                 target_info=None, multi_stage=True,   # Enable multi-stage optimization
                 projection_directions=30):            # More projection directions
        """
        # & Build the sampler, creating default kernel / GSWD components if needed
        # &
        # & Args:
        # &     kernel: Kernel object; defaults to EnhancedRBFKernel, or RBFKernel
        # &         if that cannot be constructed
        # &     gswd: GSWD regularizer; defaults to ImprovedGSWD, or plain GSWD
        # &         if that cannot be constructed
        # &     step_size, n_iter, tol, lambda_reg, decay_step_size, verbose:
        # &         forwarded unchanged to AdaptiveSVGD.__init__
        # &     noise_level, noise_decay, target_info, multi_stage: stored on the
        # &         instance (NOTE(review): not read anywhere in this class body;
        # &         confirm a consumer exists)
        # &     projection_directions (int): Number of GSWD projections used when
        # &         ``gswd`` is None
        """
        if kernel is None:
            try:
                kernel = EnhancedRBFKernel(adaptive=True, min_bandwidth=0.08, max_bandwidth=12.0)
            except Exception:
                # Enhanced kernel unavailable or rejected the arguments
                kernel = RBFKernel(adaptive=True)

        if gswd is None:
            try:
                gswd = ImprovedGSWD(
                    n_projections=projection_directions, 
                    projection_method='optimized', 
                    optimization_steps=10, 
                    correlation_aware=True,
                    mode_aware=True
                )
            except Exception:
                # Fall back to the stock GSWD implementation
                gswd = GSWD(
                    n_projections=projection_directions,
                    projection_method='random', 
                    optimization_steps=5, 
                    correlation_aware=True
                )

        super().__init__(
            kernel=kernel, gswd=gswd, step_size=step_size,
            n_iter=n_iter, tol=tol, lambda_reg=lambda_reg,
            decay_step_size=decay_step_size, verbose=verbose
        )

        # Store enhanced parameters
        self.target_info = target_info
        self.detected_modes = None
        self.mode_assignments = None
        self.noise_level = noise_level
        self.noise_decay = noise_decay
        self.multi_stage = multi_stage

        # Cache for frequently used computations
        self._mode_centers = None
        self._mode_covs = None
        self._cholesky_cache = {}
        self._eigendecomposition_cache = {}

        # Adaptive annealing parameters
        self._annealing_schedule = None
        self._projection_weights = None
        self._mode_weights = None
        self._gswd_weight_schedule = None

        # Mode persistence tracking
        self._mode_persistence = None
        self._particle_history = []
        self._score_cache = {}

        # Metrics tracking
        self.metrics_history = {
            'gswd_loss': [],
            'mode_coverage': [],
            'correlation_error': []
        }


class OptimizedESCORTAdapter:
    """
    Adapter for OptimizedESCORT

    Exposes the ``fit_transform(initial_particles, score_fn, target_samples,
    return_convergence)`` interface used by the evaluation loop.
    """
    def __init__(self, n_iter=300, step_size=0.02, verbose=True, target_info=None,
                 noise_level=0.15, noise_decay=0.98, lambda_reg=0.2):
        """
        Args:
            n_iter (int): Number of sampler iterations
            step_size (float): Initial step size
            verbose (bool): Whether the sampler reports progress
            target_info (dict | None): Optional target description handed to the sampler
            noise_level (float): Noise level passed to OptimizedESCORT
            noise_decay (float): Noise decay passed to OptimizedESCORT
            lambda_reg (float): GSWD regularization weight passed to OptimizedESCORT
        """
        self.escort = OptimizedESCORT(
            step_size=step_size,
            n_iter=n_iter,
            verbose=verbose,
            noise_level=noise_level,
            noise_decay=noise_decay,
            lambda_reg=lambda_reg,
            target_info=target_info
        )

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """Delegate to the wrapped sampler's fit_transform (arguments passed positionally)."""
        return self.escort.fit_transform(
            initial_particles, score_fn, target_samples, return_convergence)


# ========================================
# Target Distribution - Unchanged
# ========================================

class CorrelatedGMMDistribution(GMMDistribution):
    """
    # & Three-component 2D GMM whose components each have a different correlation
    # & structure (positive, none, negative), used as a hard test target
    """
    def __init__(self, name=None, seed=None):
        """
        # & Build the fixed three-component mixture
        # &
        # & Args:
        # &     name (str): Display name; falls back to "Correlated GMM" when falsy
        # &     seed (int): Random seed forwarded to GMMDistribution
        """
        # Component centres sit on the diagonal y = x
        component_means = np.array([[-2.0, -2.0], [0.0, 0.0], [2.0, 2.0]])

        component_covs = np.array([
            [[1.0, 0.8], [0.8, 1.0]],    # bottom-left, rho = +0.8
            [[0.5, 0.0], [0.0, 0.5]],    # centre, isotropic
            [[1.0, -0.8], [-0.8, 1.0]],  # top-right, rho = -0.8
        ])

        # Centre component is slightly under-weighted
        component_weights = np.array([0.35, 0.3, 0.35])

        label = "Correlated GMM" if not name else name
        super().__init__(component_means, component_covs, component_weights,
                         name=label, seed=seed)


# ========================================
# Evaluation Helper Functions (with multi-run statistics)
# ========================================

def create_target_distribution():
    """
    # & Build the benchmark target used by every run
    # &
    # & Returns:
    # &     CorrelatedGMMDistribution: The three-mode correlated GMM, labelled
    # &     "Correlated 3-Mode GMM"
    """
    target = CorrelatedGMMDistribution(name="Correlated 3-Mode GMM")
    return target


def _make_initial_particles(run_idx, n_particles, target_gmm):
    """
    # & Draw the initial 2D particle cloud for one run; the strategy cycles with run_idx
    # &
    # & Consumes the global NumPy RNG, so call it right after seeding.
    # &
    # & Returns:
    # &     tuple: (particles of shape (n_particles, 2), human-readable description)
    """
    initialization_type = run_idx % 4

    if initialization_type == 0:
        return np.random.randn(n_particles, 2) * 2.0, "Standard Gaussian"
    if initialization_type == 1:
        return np.random.uniform(-5, 5, (n_particles, 2)), "Uniform [-5, 5]"
    if initialization_type == 2:
        # Concentrated around one randomly chosen target mode
        mode_idx = np.random.choice(len(target_gmm.means))
        selected_mode = target_gmm.means[mode_idx]
        particles = np.random.randn(n_particles, 2) * 0.5 + selected_mode
        return particles, f"Concentrated around mode {mode_idx+1}"

    # Multi-modal initialization around centres that differ from the target modes
    modes = np.array([[-3.0, -3.0], [0.0, 1.0], [3.0, 3.0], [0.0, -2.0]])
    mode_idx = np.random.choice(len(modes), n_particles)
    particles = modes[mode_idx] + np.random.randn(n_particles, 2) * 0.5
    return particles, "Custom multimodal"


def _load_escort_adapter_class():
    """
    # & Locate ESCORTAdapter, trying the package path used by the file-level imports first
    # &
    # & Returns:
    # &     type | None: The adapter class, or None if neither module provides it
    """
    try:
        from tests.evaluate_2d_1 import ESCORTAdapter
        return ESCORTAdapter
    except ImportError:
        pass
    try:
        from evaluate_2d_1 import ESCORTAdapter
        return ESCORTAdapter
    except ImportError:
        return None


def _build_methods(methods_to_run, n_iter, step_size, verbose, target_info, n_particles):
    """
    # & Instantiate the requested samplers for one run, in a fixed evaluation order
    # &
    # & Returns:
    # &     dict: method name -> adapter object (None when construction failed)
    """
    methods = {}

    if 'ImprovedESCORT' in methods_to_run:
        methods['ImprovedESCORT'] = OptimizedESCORTAdapter(
            n_iter=n_iter, step_size=step_size, verbose=verbose, target_info=target_info)

    if 'ESCORT' in methods_to_run:
        escort_adapter_cls = _load_escort_adapter_class()
        if escort_adapter_cls is not None:
            methods['ESCORT'] = escort_adapter_cls(n_iter=n_iter, step_size=step_size, verbose=verbose)
        else:
            print("Warning: ESCORTAdapter not found, using default AdaptiveSVGD")
            methods['ESCORT'] = AdaptiveSVGD(
                step_size=step_size, n_iter=n_iter, verbose=verbose)

    if 'EnhancedSVGD' in methods_to_run:
        methods['EnhancedSVGD'] = EnhancedSVGDAdapter(
            n_iter=n_iter, step_size=step_size, verbose=verbose, target_info=target_info)

    if 'SVGD' in methods_to_run:
        # StableSVGDAdapter stands in for the standard SVGD baseline
        methods['SVGD'] = StableSVGDAdapter(
            n_iter=n_iter, step_size=step_size, verbose=verbose, target_info=target_info)

    if 'DVRL' in methods_to_run:
        try:
            dvrl = DVRL(
                obs_dim=2,          # Observation dimension (2D state space)
                action_dim=1,       # Simple 1D actions for testing
                h_dim=64,           # Hidden state dimension
                z_dim=2,            # Latent state dimension (matches state dimension)
                n_particles=100,    # Use fewer particles for stability
                continuous_actions=True  # Assuming continuous action space
            )
            # Explicitly move model to CPU before adapter creation
            dvrl = dvrl.to(torch.device('cpu'))
            methods['DVRL'] = DVRLAdapter(dvrl, n_samples=n_particles)
        except Exception as e:
            print(f"Error initializing DVRL: {e}")
            methods['DVRL'] = None

    if 'SIR' in methods_to_run:
        methods['SIR'] = SIRAdapter(n_iter=1)  # Just one iteration for SIR

    return methods


def _summarize_results(all_results_df, methods_to_run, metrics):
    """
    # & Aggregate per-run metrics into mean / standard error per method
    # &
    # & Returns:
    # &     pd.DataFrame: Indexed by "Method" with "<metric>_mean", "<metric>_se"
    # &     and a formatted "<metric>" column; empty (but still indexed by
    # &     "Method") when no evaluation succeeded
    """
    mean_results = []

    # With zero successful evaluations the frame has no "Method" column at all
    if "Method" in all_results_df.columns:
        for method in methods_to_run:
            method_data = all_results_df[all_results_df["Method"] == method]
            if method_data.empty:
                continue

            method_stats = {"Method": method}
            for metric in metrics:
                if metric not in method_data.columns:
                    continue
                values = method_data[metric].values
                mean_val = np.mean(values)
                # SE is undefined for a single run; report NaN without scipy's warning
                se_val = sem(values) if len(values) > 1 else np.nan

                method_stats[f"{metric}_mean"] = mean_val
                method_stats[f"{metric}_se"] = se_val
                method_stats[f"{metric}"] = f"{mean_val:.6f} ± {se_val:.6f}"

            mean_results.append(method_stats)

    if not mean_results:
        return pd.DataFrame(index=pd.Index([], name="Method"))
    return pd.DataFrame(mean_results).set_index('Method')


def run_experiment_with_multiple_inits(methods_to_run=None, n_runs=5, seeds=None, 
                                      n_iter=300, step_size=0.02, verbose=True):
    """
    # & Run experiment comparing different methods with multiple initializations
    # &
    # & The target distribution and its 1000 reference samples are fixed (seed 42).
    # & Each run seeds NumPy/torch with its own seed and draws one of four
    # & initialization strategies. Before each method, the global RNGs are
    # & re-seeded from (run seed, method name), so a method's result does not
    # & depend on which other methods ran before it.
    # &
    # & Args:
    # &     methods_to_run (list): Methods to evaluate
    # &     n_runs (int): Number of runs; ignored (taken as len(seeds)) when seeds is given
    # &     seeds (list): Random seeds for initializations
    # &     n_iter (int): Number of iterations
    # &     step_size (float): Step size for updates
    # &     verbose (bool): Whether to display progress
    # &
    # & Returns:
    # &     tuple: (mean_results_df, all_results_df, all_particles, target_distribution)
    # &         all_particles maps method name -> list of particle arrays, one per
    # &         successful run
    """
    import zlib

    if seeds is not None:
        # Keep the reported run count consistent with the seeds actually used
        n_runs = len(seeds)

    print(f"Starting 2D GMM evaluation experiment with {n_runs} different initializations...")

    if methods_to_run is None:
        methods_to_run = ['ImprovedESCORT', 'ESCORT', 'EnhancedSVGD', 'SVGD', 'DVRL', 'SIR']

    if seeds is None:
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=n_runs)

    print(f"Using seeds: {seeds}")

    # Same target (and reference samples) for all runs to ensure fair comparison
    np.random.seed(42)
    target_gmm = create_target_distribution()
    n_particles = 1000
    target_samples = target_gmm.sample(n_particles)
    score_fn = target_gmm.score

    target_info = {
        'n_modes': 3,
        'centers': target_gmm.means,
        'covs': [cov for cov in target_gmm.covs]
    }

    all_particles = {method: [] for method in methods_to_run}
    all_results = []

    for run_idx, seed in enumerate(seeds):
        print(f"\n=== Run {run_idx+1}/{n_runs} (Initialization Seed: {seed}) ===")

        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        initial_particles, init_description = _make_initial_particles(
            run_idx, n_particles, target_gmm)
        print(f"  Using {init_description} initialization")

        methods = _build_methods(
            methods_to_run, n_iter, step_size, verbose, target_info, n_particles)

        for method_name, method in methods.items():
            if method is None:
                continue

            print(f"  Running {method_name}...")

            # Deterministic per-(run, method) seed: crc32 is stable across
            # processes (unlike hash()), and the modulus keeps it a valid seed.
            method_seed = (int(seed) + zlib.crc32(method_name.encode("utf-8"))) % (2 ** 32)
            np.random.seed(method_seed)
            torch.manual_seed(method_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(method_seed)

            try:
                # Small method-specific perturbation of the shared initialization
                method_particles = initial_particles + np.random.randn(*initial_particles.shape) * 0.1

                start_time = time.time()
                particles, _ = method.fit_transform(
                    method_particles, score_fn, target_samples, return_convergence=True)
                runtime = time.time() - start_time

                all_particles[method_name].append(particles)

                try:
                    evaluation = evaluate_method(
                        method_name, particles, target_gmm, target_samples,
                        runtime=runtime)
                    evaluation["Run"] = run_idx + 1
                    evaluation["Seed"] = seed
                    evaluation["Initialization"] = init_description
                    evaluation["Method"] = method_name
                    all_results.append(evaluation)
                except Exception as e:
                    print(f"Error evaluating {method_name}: {e}")
                    traceback.print_exc()
            except Exception as e:
                print(f"Error in {method_name}: {e}")
                traceback.print_exc()

    all_results_df = pd.DataFrame(all_results)

    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)', 
              'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein', 'Runtime (s)']
    mean_results_df = _summarize_results(all_results_df, methods_to_run, metrics)

    print("\nResults Summary (Mean ± Standard Error):")
    display_cols = [metric for metric in metrics if metric in mean_results_df.columns]
    print(mean_results_df[display_cols])

    return mean_results_df, all_results_df, all_particles, target_gmm


def _covariance_ellipse_geometry(cov, chi2_quantile=5.991):
    """Return ``(width, height, angle_degrees)`` of the confidence ellipse of a 2x2 covariance.

    ``chi2_quantile`` defaults to 5.991, the 0.95 quantile of a chi-square
    distribution with 2 degrees of freedom, so the ellipse outlines the ~95%
    region of a 2D Gaussian with covariance ``cov``.
    """
    eigvals, eigvecs = np.linalg.eigh(cov)
    order = eigvals.argsort()[::-1]  # major axis first
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]
    angle = np.arctan2(eigvecs[1, 0], eigvecs[0, 0])
    width = 2 * np.sqrt(chi2_quantile * eigvals[0])
    height = 2 * np.sqrt(chi2_quantile * eigvals[1])
    return width, height, np.degrees(angle)


def _draw_gmm_components(ax, target_gmm, color, ellipse_alpha, label_prefix):
    """Mark each GMM component mean with a star and outline its 95% covariance ellipse on ``ax``.

    Only the first component's star is labelled (``"<label_prefix> 1"``); the
    others get an empty label so the legend shows a single entry for all modes.
    """
    for i, (mean, cov) in enumerate(zip(target_gmm.means, target_gmm.covs)):
        ax.scatter(mean[0], mean[1], s=150, c=color, edgecolor='black', marker='*',
                   label=f'{label_prefix} {i+1}' if i == 0 else "")
        width, height, angle_deg = _covariance_ellipse_geometry(cov)
        ax.add_patch(Ellipse(xy=mean, width=width, height=height, angle=angle_deg,
                             edgecolor=color, fc='none', lw=2, alpha=ellipse_alpha))


def _finite_or_zero(value):
    """Return ``value`` as a float, or 0.0 when it is NaN/inf (e.g. the SE of a single run)."""
    value = float(value)
    return value if np.isfinite(value) else 0.0


def visualize_results_with_error_bars(mean_results_df, all_results_df, all_particles, target_gmm):
    """
    # & Visualize experiment results with error bars and save the figures as PNG files
    # &
    # & Writes into ``<SCRIPT_DIR>/plots``:
    # &   - target_2d_distribution.png: target contours, samples, modes and 95% ellipses
    # &   - <method>_2d_approximation.png: one per method that has particles
    # &   - metrics_comparison_with_errors.png: per-metric bar charts (mean ± SE) + runtime
    # &   - metrics_boxplots.png: distribution of each metric across runs
    # &   - metrics_by_initialization.png: only when more than one initialization type exists
    # &   - top_methods_comparison.png: only when there are >= 3 methods and the
    # &     Mode Coverage / Correlation Error summary columns exist
    # &
    # & Side effect: reseeds NumPy's global RNG with 42 before drawing target samples.
    # &
    # & Args:
    # &     mean_results_df: DataFrame indexed by method with "<metric>" (formatted
    # &         "mean ± se" string), "<metric>_mean" and "<metric>_se" columns
    # &     all_results_df: DataFrame with one row per (method, run); needs a "Method"
    # &         column and optionally "Initialization" and per-metric columns
    # &     all_particles: Dictionary mapping method name -> list of per-run particle
    # &         arrays (indexed as [:, 0] / [:, 1])
    # &     target_gmm: Target GMM distribution exposing log_prob, sample, means and covs
    # &
    # & Returns:
    # &     None
    """
    # Create directory for plots
    plots_dir = os.path.join(SCRIPT_DIR, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # Get methods to visualize
    methods = list(mean_results_df.index)

    # Evaluation grid shared by every contour plot below
    x_min, x_max = -6, 6
    y_min, y_max = -6, 6
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100), np.linspace(y_min, y_max, 100))
    grid_points = np.column_stack([xx.ravel(), yy.ravel()])
    probs = np.exp(target_gmm.log_prob(grid_points)).reshape(xx.shape)

    # ------------------------------------------------------------------
    # Target distribution
    # ------------------------------------------------------------------
    plt.figure(figsize=(10, 8))
    plt.contour(xx, yy, probs, 15, colors='k', alpha=0.5, linewidths=0.5)

    # Fixed seed so the target sample cloud is the same on every call
    np.random.seed(42)
    target_samples = target_gmm.sample(1000)
    plt.scatter(target_samples[:, 0], target_samples[:, 1], s=10, alpha=0.3, c='blue', label='Target Samples')

    _draw_gmm_components(plt.gca(), target_gmm, color='red', ellipse_alpha=0.7, label_prefix='Mode')

    plt.title("Target 2D Correlated GMM Distribution", fontsize=14)
    plt.xlabel('x₁', fontsize=12)
    plt.ylabel('x₂', fontsize=12)
    plt.xlim([x_min, x_max])
    plt.ylim([y_min, y_max])
    plt.grid(alpha=0.3)
    plt.legend(loc='upper right')
    plt.savefig(os.path.join(plots_dir, "target_2d_distribution.png"), dpi=300)
    plt.close()

    # ------------------------------------------------------------------
    # Each method's particles over the target contours
    # ------------------------------------------------------------------
    for method_name in methods:
        method_particles_list = all_particles.get(method_name, [])
        # Check before creating the figure so no figure is left open for
        # methods without particles.
        if not method_particles_list:
            continue

        plt.figure(figsize=(10, 8))
        plt.contour(xx, yy, probs, 15, colors='k', alpha=0.3, linewidths=0.5)

        # First run is emphasised; later runs are faint to show run-to-run variability
        for run_idx, particles in enumerate(method_particles_list):
            if run_idx == 0:
                plt.scatter(particles[:, 0], particles[:, 1], s=15, alpha=0.6, c='red',
                            label=f'{method_name} Particles (Run 1)')
            else:
                plt.scatter(particles[:, 0], particles[:, 1], s=8, alpha=0.2, c='red')

        _draw_gmm_components(plt.gca(), target_gmm, color='blue', ellipse_alpha=0.5,
                             label_prefix='Target Mode')

        # Formatted "mean ± se" strings; tolerate a metric that was never computed
        summary = mean_results_df.loc[method_name]
        mode_coverage = summary.get('Mode Coverage', 'N/A')
        correlation_error = summary.get('Correlation Error', 'N/A')

        plt.title(f"{method_name} Approximation\n{mode_coverage} (Mode Coverage), {correlation_error} (Corr. Error)",
                  fontsize=14)
        plt.xlabel('x₁', fontsize=12)
        plt.ylabel('x₂', fontsize=12)
        plt.xlim([x_min, x_max])
        plt.ylim([y_min, y_max])
        plt.grid(alpha=0.3)
        plt.legend(loc='upper right')
        plt.savefig(os.path.join(plots_dir, f"{method_name}_2d_approximation.png"), dpi=300)
        plt.close()

    # ------------------------------------------------------------------
    # Metric comparison bar charts (mean ± SE)
    # ------------------------------------------------------------------
    plt.figure(figsize=(15, 10))

    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)',
               'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein']

    for i, metric in enumerate(metrics):
        if metric not in mean_results_df.columns:
            continue

        plt.subplot(2, 4, i + 1)

        mean_col, se_col = f"{metric}_mean", f"{metric}_se"
        if mean_col not in mean_results_df.columns or not methods:
            continue
        methods_with_data = methods

        raw_means = [mean_results_df.loc[m, mean_col] for m in methods_with_data]
        errors = [mean_results_df.loc[m, se_col] for m in methods_with_data]

        # Some metrics are clipped for readability; ignore NaN when finding the peak
        finite_means = [v for v in raw_means if np.isfinite(v)]
        peak = max(finite_means) if finite_means else None
        clip_upper = None
        if peak is not None:
            if 'KL' in metric and peak > 5:
                clip_upper = 5.0
            elif 'Correlation Error' in metric and peak > 2:
                clip_upper = 2.0
            elif 'Sliced Wasserstein' in metric and peak > 5:
                clip_upper = 5.0

        plot_means = list(raw_means)
        if clip_upper is not None:
            # Annotate clipped bars with their true value, then clip the bar heights
            for j, val in enumerate(raw_means):
                if val > clip_upper:
                    plt.text(j, clip_upper * 0.95, f"↑ {val:.2f}", ha='center', va='top')
            plot_means = [min(v, clip_upper) for v in raw_means]

        bars = plt.bar(methods_with_data, plot_means, yerr=errors, capsize=5)

        # Value labels show the unclipped mean; NaN SE (single run) is treated as
        # 0 only for positioning so the label still lands at a finite y.
        for j, bar in enumerate(bars):
            label_y = _finite_or_zero(bar.get_height()) + _finite_or_zero(errors[j]) + 0.01
            plt.text(bar.get_x() + bar.get_width() / 2., label_y,
                     f'{raw_means[j]:.4f}±{errors[j]:.4f}', ha='center', va='bottom', rotation=45,
                     fontsize=8)

        if metric in ['Mode Coverage', 'ESS']:
            plt.title(f"{metric} (higher is better)")
        else:
            plt.title(f"{metric} (lower is better)")

        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)

        # For some metrics, set consistent y-limits
        if metric in ['Mode Coverage', 'ESS']:
            plt.ylim(0, 1.1)

    # Runtime comparison in the last subplot slot
    plt.subplot(2, 4, 8)
    if 'Runtime (s)_mean' in mean_results_df.columns and methods:
        runtime_means = [mean_results_df.loc[m, 'Runtime (s)_mean'] for m in methods]
        runtime_errors = [mean_results_df.loc[m, 'Runtime (s)_se'] for m in methods]

        bars = plt.bar(methods, runtime_means, yerr=runtime_errors, capsize=5)
        for j, bar in enumerate(bars):
            label_y = _finite_or_zero(bar.get_height()) + _finite_or_zero(runtime_errors[j]) + 0.01
            plt.text(bar.get_x() + bar.get_width() / 2., label_y,
                     f'{runtime_means[j]:.2f}s', ha='center', va='bottom', rotation=0,
                     fontsize=9)

        plt.title("Runtime (seconds)")
        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, "metrics_comparison_with_errors.png"), dpi=300)
    plt.close()

    # ------------------------------------------------------------------
    # Box plots: distribution of each metric across runs
    # ------------------------------------------------------------------
    plt.figure(figsize=(15, 10))

    for i, metric in enumerate(metrics):
        if metric not in all_results_df.columns:
            continue

        plt.subplot(2, 4, i + 1)

        box_data = []
        box_labels = []
        for method in methods:
            method_data = all_results_df[all_results_df['Method'] == method]
            if not method_data.empty:
                box_data.append(method_data[metric].values)
                box_labels.append(method)

        if not box_data:
            continue

        plt.boxplot(box_data, patch_artist=True,
                    boxprops=dict(facecolor='lightblue', color='blue'),
                    whiskerprops=dict(color='blue'),
                    capprops=dict(color='blue'),
                    medianprops=dict(color='red'))
        # Boxes sit at x = 1..N by default; label them via xticks rather than the
        # boxplot `labels=` keyword, which Matplotlib 3.9 renamed to `tick_labels`.
        plt.xticks(range(1, len(box_labels) + 1), box_labels, rotation=45)

        plt.title(f"{metric} Distribution Across Different Initializations")
        plt.grid(axis='y', alpha=0.3)

        # For some metrics, set consistent y-limits
        if metric in ['Mode Coverage', 'ESS']:
            plt.ylim(0, 1.1)

    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, "metrics_boxplots.png"), dpi=300)
    plt.close()

    # ------------------------------------------------------------------
    # Grouped bars by initialization type (only if the column exists)
    # ------------------------------------------------------------------
    if 'Initialization' in all_results_df.columns:
        initializations = sorted(all_results_df['Initialization'].unique())
    else:
        initializations = []

    if len(initializations) > 1:
        plt.figure(figsize=(15, 15))

        bar_width = 0.15
        group_x = np.arange(len(initializations))
        colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown']

        for i, metric in enumerate(['MMD', 'Mode Coverage', 'Correlation Error']):
            if metric not in all_results_df.columns:
                continue

            plt.subplot(3, 1, i + 1)

            for j, method in enumerate(methods):
                values = []
                value_errors = []
                for init in initializations:
                    subset = all_results_df[(all_results_df['Method'] == method) &
                                            (all_results_df['Initialization'] == init)]
                    if subset.empty:
                        # Missing (method, init) combinations are drawn as zero-height bars
                        values.append(0)
                        value_errors.append(0)
                    else:
                        values.append(subset[metric].mean())
                        value_errors.append(sem(subset[metric]))

                plt.bar(group_x + bar_width * j, values, width=bar_width, label=method,
                        color=colors[j % len(colors)], yerr=value_errors, capsize=3)

            plt.xlabel('Initialization Type')
            plt.ylabel(metric)
            plt.title(f'Performance by Initialization Type - {metric}')
            # Centre each tick under its group of method bars
            plt.xticks(group_x + bar_width * (len(methods) - 1) / 2, initializations, rotation=45)
            plt.legend()
            plt.grid(axis='y', alpha=0.3)

            if metric == 'Mode Coverage':
                plt.ylim(0, 1.1)

        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "metrics_by_initialization.png"), dpi=300)
        plt.close()

    # ------------------------------------------------------------------
    # Top-3 methods by Mode Coverage * (1 - min(Correlation Error, 1))
    # ------------------------------------------------------------------
    if (len(methods) >= 3 and 'Mode Coverage_mean' in mean_results_df.columns
            and 'Correlation Error_mean' in mean_results_df.columns):
        combined_score = {}
        for method in methods:
            mode_cov = mean_results_df.loc[method, 'Mode Coverage_mean']
            corr_err = min(1.0, mean_results_df.loc[method, 'Correlation Error_mean'])  # Cap at 1.0
            combined_score[method] = mode_cov * (1.0 - corr_err)

        best_methods = sorted(combined_score.items(), key=lambda x: x[1], reverse=True)[:3]
        best_method_names = [name for name, _ in best_methods]

        fig, axes = plt.subplots(2, 2, figsize=(14, 12))
        fig.suptitle("Comparison of Top Methods for 2D Correlated Distribution", fontsize=16)

        # Target distribution (top-left)
        ax_target = axes[0, 0]
        ax_target.contour(xx, yy, probs, 15, colors='k', alpha=0.6, linewidths=0.8)
        ax_target.scatter(target_samples[:, 0], target_samples[:, 1], s=10, alpha=0.3, c='blue', label='Target Samples')
        _draw_gmm_components(ax_target, target_gmm, color='red', ellipse_alpha=0.7, label_prefix='Mode')

        ax_target.set_title("Target Distribution", fontsize=14)
        ax_target.set_xlabel('x₁', fontsize=12)
        ax_target.set_ylabel('x₂', fontsize=12)
        ax_target.set_xlim([x_min, x_max])
        ax_target.set_ylim([y_min, y_max])
        ax_target.grid(alpha=0.3)
        ax_target.legend(loc='upper right')

        # Remaining three panels hold the top methods (first run's particles)
        panel_positions = [(0, 1), (1, 0), (1, 1)]
        top_colors = ['blue', 'green', 'red']

        for i, method_name in enumerate(best_method_names):
            row, col = panel_positions[i]
            ax = axes[row, col]

            ax.contour(xx, yy, probs, 15, colors='k', alpha=0.3, linewidths=0.5)

            runs = all_particles.get(method_name, [])
            method_particles = runs[0] if runs else None
            if method_particles is not None:
                ax.scatter(method_particles[:, 0], method_particles[:, 1], s=15, alpha=0.6, c=top_colors[i],
                           label=f'{method_name}')

            mode_cov = mean_results_df.loc[method_name, 'Mode Coverage']
            corr_err = mean_results_df.loc[method_name, 'Correlation Error']
            score = combined_score[method_name]

            ax.set_title(f"{method_name}\nMode Cov: {mode_cov}\nCorr Err: {corr_err}\nScore: {score:.4f}", fontsize=12)
            ax.set_xlabel('x₁', fontsize=10)
            ax.set_ylabel('x₂', fontsize=10)
            ax.set_xlim([x_min, x_max])
            ax.set_ylim([y_min, y_max])
            ax.grid(alpha=0.3)
            ax.legend(loc='upper right')

        plt.tight_layout(rect=[0, 0, 1, 0.95])  # Leave room for the suptitle
        plt.savefig(os.path.join(plots_dir, "top_methods_comparison.png"), dpi=300)
        plt.close()

    print(f"Visualizations saved to {plots_dir}")


# ========================================
# Main Function with Multiple Seeds
# ========================================

if __name__ == "__main__":
    import argparse
    
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Improved ESCORT Framework Evaluation with Multiple Seeds')
    parser.add_argument('--methods', nargs='+', 
                    default=['ImprovedESCORT', 'ESCORT', 'EnhancedSVGD', 'SVGD', 'DVRL', 'SIR'],
                    help='Methods to evaluate (default: ImprovedESCORT ESCORT EnhancedSVGD SVGD DVRL SIR)')
    parser.add_argument('--n_iter', type=int, default=300, 
                    help='Number of iterations (default: 300)')
    parser.add_argument('--step_size', type=float, default=0.02,
                    help='Step size for updates (default: 0.02)')
    parser.add_argument('--no_verbose', action='store_false', dest='verbose',
                    help='Disable verbose output (default: verbose enabled)')
    parser.add_argument('--n_runs', type=int, default=5,
                    help='Number of runs with different initializations (default: 5)')
    parser.add_argument('--fixed_seeds', action='store_true',
                    help='Use fixed seeds instead of random ones (default: False)')
    
    # Parse arguments
    args = parser.parse_args()
    
    # Configure parameters
    method_params = {
        'n_iter': args.n_iter,
        'step_size': args.step_size,
        'verbose': args.verbose,
    }
    
    # Use fixed seeds if requested
    if args.fixed_seeds:
        seeds = [42, 123, 456, 789, 101]  # Fixed seeds for reproducibility
        seeds = seeds[:args.n_runs]  # Truncate if fewer runs requested
    else:
        seeds = None  # Generate random seeds
    
    # Run the experiment with specified methods and seeds
    mean_results_df, all_results_df, all_particles, target_gmm = run_experiment_with_multiple_inits(
        methods_to_run=args.methods,
        n_runs=args.n_runs,
        seeds=seeds,
        **method_params
    )
    
    # Save results to CSV
    results_dir = os.path.join(SCRIPT_DIR, "results")
    os.makedirs(results_dir, exist_ok=True)
    mean_results_df.to_csv(os.path.join(results_dir, "improved_escort_mean_results.csv"))
    all_results_df.to_csv(os.path.join(results_dir, "improved_escort_all_results.csv"))
    
    # Visualize the results with error bars
    visualize_results_with_error_bars(mean_results_df, all_results_df, all_particles, target_gmm)
    
    print("\nExperiment complete. Results saved to CSV and visualizations saved in the plots directory.")
    print(f"Results saved in: {results_dir}")
    print(f"Visualizations saved in: {os.path.join(SCRIPT_DIR, 'plots')}")
    
    # Report the initialization types that were most challenging for each method
    print("\nPerformance by Initialization Type:")
    for method in args.methods:
        method_data = all_results_df[all_results_df['Method'] == method]
        
        if method_data.empty:
            continue
        
        # For each initialization type, get the mean MMD
        init_performance = {}
        for init in method_data['Initialization'].unique():
            init_data = method_data[method_data['Initialization'] == init]
            
            # Use MMD as the primary metric (lower is better)
            if 'MMD' in init_data.columns:
                init_performance[init] = init_data['MMD'].mean()
        
        if not init_performance:
            continue
            
        # Sort by performance (lower MMD is better)
        sorted_inits = sorted(init_performance.items(), key=lambda x: x[1])
        
        print(f"\n{method}:")
        print(f"  Best performance on: {sorted_inits[0][0]} (MMD: {sorted_inits[0][1]:.6f})")
        print(f"  Worst performance on: {sorted_inits[-1][0]} (MMD: {sorted_inits[-1][1]:.6f})")
