"""
    Enhanced ESCORT Framework Evaluation on 3D Multi-modal Correlated Distribution
    
    This script evaluates the ESCORT framework against other methods on a 
    challenging 3D multi-modal distribution with complex correlation structures.
    
    Key features:
    1. 3D distribution with varied correlation patterns across modes
    2. Multiple random seeds for robust evaluation
    3. Statistical reporting with mean and standard error
    4. Visualization using GMMVisualizer's 3D capabilities
    5. Comparative analysis of ESCORT, SVGD, DVRL, and SIR methods
    6. Metrics for 3D distribution quality assessment
"""
import os
import sys
import time
import traceback
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import multivariate_normal, sem
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import KMeans
import torch

# Import required libraries from provided modules
from belief_assessment.distributions import GMMDistribution
from belief_assessment.evaluation.visualize_distributions import GMMVisualizer

# Try to import ESCORT-related classes and helper functions
try:
    from escort.utils.kernels import RBFKernel
    from escort.gswd import GSWD
    from escort.svgd import SVGD, AdaptiveSVGD
    from dvrl.dvrl import DVRL
    from tests.evaluate_2d_1 import SIRAdapter
    from tests.evaluate_2d_1 import compute_mmd, compute_ess, compute_mode_coverage_2d
    from tests.evaluate_2d_1 import compute_correlation_error, compute_sliced_wasserstein_distance
    from tests.evaluate_2d_1 import estimate_kl_divergence_2d, evaluate_method
except ImportError as e:
    print(f"Import error: {e}")
    print("Implementing required adapter classes and metrics...")

    # Fallback adapter used only when the project imports above are unavailable
    class SIRAdapter:
        """
        # & Adapter for Sequential Importance Resampling
        # &
        # & Exposes the same fit_transform interface as the other methods.
        # &
        # & Args:
        # &     n_iter (int): Number of weighting/resampling passes
        """
        def __init__(self, n_iter=1):
            self.n_iter = n_iter

        def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
            """
            # & Reweight particles and resample them when the ESS is low
            # &
            # & Args:
            # &     initial_particles (np.ndarray): Starting particles (not modified)
            # &     score_fn (callable): Called as score_fn(x, return_logp=True) first;
            # &         if that fails, score_fn(x) is used as the log-weights
            # &     target_samples (np.ndarray, optional): Unused, kept for interface parity
            # &     return_convergence (bool): Also return a convergence-info dict
            # &
            # & Returns:
            # &     np.ndarray: Resulting particles
            # &     dict (optional): {"iterations": ...} if return_convergence=True
            # &         (0 iterations when an error forced a fallback to the input)
            """
            completed_iterations = self.n_iter
            try:
                particles = initial_particles.copy()

                for step in range(self.n_iter):
                    # Prefer explicit log-probabilities; fall back to the plain call
                    try:
                        _, log_probs = score_fn(particles, return_logp=True)
                    except Exception:
                        log_probs = score_fn(particles)

                    # Subtract the max before exponentiating for numerical stability
                    probs = np.exp(log_probs - np.max(log_probs))
                    weights = probs / np.sum(probs)

                    # Normalized effective sample size in (0, 1]
                    ess = 1.0 / np.sum(weights**2)
                    ess_ratio = ess / len(particles)

                    # Resample only when the weights have degenerated
                    if ess_ratio < 0.5:
                        print(f"Iteration {step}: Resampling (ESS = {ess_ratio:.4f})")
                        # Multinomial resampling
                        indices = np.random.choice(
                            len(particles), len(particles), p=weights, replace=True
                        )
                        particles = particles[indices]

                        # Jitter the duplicated particles
                        particles += np.random.randn(*particles.shape) * 0.1
            except Exception as err:
                print(f"Error in SIR: {err}")
                # Best-effort: hand back the input, but keep the return shape
                # consistent with return_convergence so callers can still unpack
                particles = initial_particles
                completed_iterations = 0

            if return_convergence:
                return particles, {"iterations": completed_iterations}
            return particles

# Absolute path of the directory that contains this script (intended as the location for saved outputs)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


class HighlyCorrelated3DGMMDistribution(GMMDistribution):
    """
    # & Six-component 3D Gaussian mixture with strongly varying covariance
    # & structure per component, used to stress-test correlation modeling
    # &
    # & Args:
    # &     name (str, optional): Display name; defaults to "Highly Correlated 3D GMM"
    # &     seed (int, optional): Random seed forwarded to GMMDistribution
    """
    def __init__(self, name=None, seed=None):
        # One row per mixture component (x, y, z)
        mode_centers = np.array([
            [-2.5, -2.5, -2.5],    # Mode 1: strong XY correlation
            [2.5, -2.5, 2.5],      # Mode 2: strong XZ correlation
            [-2.5, 2.5, 2.5],      # Mode 3: strong YZ correlation
            [2.5, 2.5, -2.5],      # Mode 4: mixed positive/negative correlations
            [0.0, 0.0, 4.0],       # Mode 5: correlations of different magnitudes
            [0.0, 0.0, -4.0]       # Mode 6: highly elongated along Z
        ])

        # Mode 1: XY correlation 1.1 / 1.2 ~= 0.917, Z independent
        cov_xy = [[1.2, 1.1, 0.0],
                  [1.1, 1.2, 0.0],
                  [0.0, 0.0, 0.3]]

        # Mode 2: XZ correlation 1.1 / 1.2 ~= 0.917, Y independent
        cov_xz = [[1.2, 0.0, 1.1],
                  [0.0, 0.3, 0.0],
                  [1.1, 0.0, 1.2]]

        # Mode 3: YZ correlation 1.1 / 1.2 ~= 0.917, X independent
        cov_yz = [[0.3, 0.0, 0.0],
                  [0.0, 1.2, 1.1],
                  [0.0, 1.1, 1.2]]

        # Mode 4: +0.7 between X and Y, -0.7 between Z and each of X, Y
        cov_mixed = [[1.0, 0.7, -0.7],
                     [0.7, 1.0, -0.7],
                     [-0.7, -0.7, 1.0]]

        # Mode 5: unequal variances with pairwise covariances of decreasing size
        cov_hierarchical = [[2.0, 0.9, 0.4],
                            [0.9, 1.0, 0.2],
                            [0.4, 0.2, 0.5]]

        # Mode 6: axis-aligned, Z variance 20x the X/Y variance (non-isotropic)
        cov_elongated = [[0.2, 0.0, 0.0],
                         [0.0, 0.2, 0.0],
                         [0.0, 0.0, 4.0]]

        mode_covariances = np.array([
            cov_xy,
            cov_xz,
            cov_yz,
            cov_mixed,
            cov_hierarchical,
            cov_elongated,
        ])

        # Uneven mixture weights (sum to 1)
        mixture_weights = np.array([0.2, 0.15, 0.15, 0.2, 0.15, 0.15])

        display_name = name or "Highly Correlated 3D GMM"
        super().__init__(mode_centers, mode_covariances, mixture_weights, name=display_name, seed=seed)


# ========================================
# 3D Evaluation Metrics
# ========================================

def compute_mmd_3d(particles, target_samples, bandwidth=None):
    """
    # & Compute the squared Maximum Mean Discrepancy between particles and target
    # & using the RBF kernel k(x, y) = exp(-||x - y||^2 / bandwidth)
    # &
    # & The within-set terms exclude the diagonal (unbiased form); a set with
    # & fewer than two points contributes 0 for its within-set term.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     target_samples (np.ndarray): Target distribution samples, shape (m, d)
    # &     bandwidth (float, optional): Kernel bandwidth; when None, the median
    # &         squared distance among (a subset of) the particles is used, with
    # &         a fallback of 1.0 if that median is not strictly positive
    # &
    # & Returns:
    # &     float: MMD value, clipped at 0
    # &
    # & Raises:
    # &     ValueError: If either sample set is empty
    """
    particles = np.asarray(particles, dtype=float)
    target_samples = np.asarray(target_samples, dtype=float)
    # Treat 1-D input as n scalar samples
    if particles.ndim == 1:
        particles = particles[:, None]
    if target_samples.ndim == 1:
        target_samples = target_samples[:, None]

    n_p = len(particles)
    n_t = len(target_samples)
    if n_p == 0 or n_t == 0:
        raise ValueError("compute_mmd_3d requires non-empty particles and target_samples")

    # Median heuristic: squared distances from the first 100 particles to the
    # first 1000 particles (self-distances included, as a cheap approximation)
    if bandwidth is None:
        subset = particles[:1000]
        anchors = subset[:100]
        sq_dists = np.sum((anchors[:, None, :] - subset[None, :, :]) ** 2, axis=-1)
        bandwidth = float(np.median(sq_dists))
        # Coincident particles give a zero median, which would make the kernel NaN
        if not np.isfinite(bandwidth) or bandwidth <= 0.0:
            bandwidth = 1.0

    def kernel_sum(a, b, chunk=256):
        # Sum of k(a_i, b_j) over all pairs, processed in row chunks to bound memory
        total = 0.0
        for start in range(0, len(a), chunk):
            rows = a[start:start + chunk]
            sq = np.sum((rows[:, None, :] - b[None, :, :]) ** 2, axis=-1)
            total += np.exp(-sq / bandwidth).sum()
        return total

    # Use subsampling for large datasets
    max_samples = 1000
    if n_p > max_samples:
        particles = particles[np.random.choice(n_p, max_samples, replace=False)]
        n_p = max_samples
    if n_t > max_samples:
        target_samples = target_samples[np.random.choice(n_t, max_samples, replace=False)]
        n_t = max_samples

    # Off-diagonal mean: the full Gram sum minus the n diagonal entries (k(x, x) = 1)
    if n_p > 1:
        pp_term = (kernel_sum(particles, particles) - n_p) / (n_p * (n_p - 1))
    else:
        pp_term = 0.0
    if n_t > 1:
        tt_term = (kernel_sum(target_samples, target_samples) - n_t) / (n_t * (n_t - 1))
    else:
        tt_term = 0.0
    pt_term = kernel_sum(particles, target_samples) / (n_p * n_t)

    mmd = pp_term + tt_term - 2 * pt_term
    return float(max(0.0, mmd))  # Ensure non-negative


def compute_ess_3d(particles, score_fn):
    """
    # & Compute the normalized Effective Sample Size of particles under
    # & self-normalized weights derived from score_fn
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate
    # &     score_fn (callable): Called as score_fn(particles, return_logp=True)
    # &         and expected to return (score, log_probs); if that call fails,
    # &         score_fn(particles) is used directly as the log-weights
    # &
    # & Returns:
    # &     float: ESS divided by the number of particles; 0.0 if the
    # &         computation fails
    """
    try:
        # Prefer explicit log-probabilities; fall back to the plain call
        try:
            _, log_probs = score_fn(particles, return_logp=True)
        except Exception:
            log_probs = score_fn(particles)

        # Shift by the max before exponentiating to avoid overflow
        unnormalized = np.exp(log_probs - np.max(log_probs))
        weights = unnormalized / np.sum(unnormalized)

        # ESS = 1 / sum(w_i^2), reported relative to the particle count
        ess = 1.0 / np.sum(weights**2)
        return ess / len(particles)
    except Exception as e:
        print(f"Error computing ESS: {e}")
        return 0.0


def compute_mode_coverage_3d(particles, gmm, threshold=0.1, mahalanobis_threshold=5.0):
    """
    # & Compute the fraction of GMM modes that are covered by particles
    # &
    # & A mode counts as covered when at least
    # & threshold * len(particles) / n_modes particles lie within a
    # & Mahalanobis radius of its mean. The radius grows with the eigenvalue
    # & ratio of the mode covariance, and the required count is reduced by 30%
    # & for modes whose eigenvalue ratio exceeds 5.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     gmm (GMMDistribution): Target GMM (uses n_components, means, covs)
    # &     threshold (float): Fraction of an even per-mode share required
    # &     mahalanobis_threshold (float): Base Mahalanobis radius
    # &
    # & Returns:
    # &     float: Mode coverage ratio in [0,1]; 0.0 for an empty particle set
    """
    particles = np.asarray(particles)
    n_modes = gmm.n_components
    n_particles = len(particles)

    # Without this guard the required count is 0 and every mode would be "covered"
    if n_particles == 0:
        return 0.0

    base_required = threshold * n_particles / n_modes
    modes_covered = np.zeros(n_modes, dtype=bool)

    for i in range(n_modes):
        cov = np.asarray(gmm.covs[i], dtype=float)
        diff = particles - gmm.means[i]

        # Eigenvalue ratio drives both the radius and the required count
        try:
            eigvals = np.linalg.eigvalsh(cov)
            anisotropy = np.max(eigvals) / np.min(eigvals)
        except np.linalg.LinAlgError:
            anisotropy = None

        if anisotropy is None:
            mode_threshold = mahalanobis_threshold
        else:
            # More elongated covariance -> larger radius
            mode_threshold = mahalanobis_threshold * (1.0 + np.log10(anisotropy + 1.0))

        try:
            # Small ridge keeps the inverse stable for near-singular covariances
            inv_cov = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
            # Per-particle quadratic form diff_i^T inv_cov diff_i
            distances = np.sqrt(np.einsum('ij,jk,ik->i', diff, inv_cov, diff))
        except np.linalg.LinAlgError:
            # Fallback to Euclidean distance with a fixed radius
            distances = np.sqrt(np.sum(diff**2, axis=1))
            mode_threshold = 4.0

        close_particles = np.sum(distances < mode_threshold)

        required_count = base_required
        if anisotropy is not None and anisotropy > 5.0:  # Highly elongated
            required_count *= 0.7  # Reduce requirement by 30%

        if close_particles >= required_count:
            modes_covered[i] = True

    return np.mean(modes_covered)

def compute_correlation_error_3d(particles, gmm):
    """
    # & Compute error in capturing the per-mode covariance structure
    # &
    # & Particles are clustered with KMeans, clusters are matched to GMM modes
    # & by minimum total squared distance between centers, and for each mode
    # & the relative Frobenius error ||cov_emp - cov_true|| / ||cov_true|| is
    # & computed. Modes with at most 3 assigned particles score 1.0.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     gmm (GMMDistribution): Target GMM (uses n_components, means, covs)
    # &
    # & Returns:
    # &     float: Mean relative covariance error over modes; 1.0 when there
    # &         are fewer than 5 particles per mode or the computation fails
    """
    n_modes = gmm.n_components

    # Skip if too few particles
    if len(particles) < n_modes * 5:
        return 1.0  # Maximum error

    try:
        # Use KMeans to cluster particles
        kmeans = KMeans(n_clusters=n_modes, random_state=42)
        cluster_labels = kmeans.fit_predict(particles)
        cluster_centers = kmeans.cluster_centers_

        # Squared distance between every cluster center and every GMM mean
        true_means = np.asarray(gmm.means)
        cost_matrix = np.sum(
            (cluster_centers[:, None, :] - true_means[None, :, :]) ** 2, axis=-1
        )

        try:
            from scipy.optimize import linear_sum_assignment
        except ImportError:
            # No Hungarian solver: map each cluster to its nearest mode
            # (several clusters may then share one mode)
            cluster_to_mode = np.argmin(cost_matrix, axis=1)
        else:
            # One-to-one matching that minimizes the total squared distance
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
            cluster_to_mode = np.empty(n_modes, dtype=int)
            cluster_to_mode[row_ind] = col_ind

        mode_labels = cluster_to_mode[cluster_labels]

        # Compute covariance error for each mode
        mode_errors = []
        for i in range(n_modes):
            mode_particles = particles[mode_labels == i]

            # Need at least 4 particles for a 3D covariance estimate
            if len(mode_particles) <= 3:
                mode_errors.append(1.0)
                continue

            empirical_cov = np.cov(mode_particles, rowvar=False)
            true_cov = gmm.covs[i]

            diff_norm = np.linalg.norm(empirical_cov - true_cov, 'fro')
            true_norm = np.linalg.norm(true_cov, 'fro')

            # Normalize; a (near-)zero true covariance cannot be normalized
            mode_errors.append(diff_norm / true_norm if true_norm > 1e-10 else 1.0)

        return np.mean(mode_errors)
    except Exception as e:
        print(f"Error computing correlation error: {e}")
        return 1.0  # Maximum error on failure


def compute_correlation_error_3d_enhanced(particles, gmm):
    """
    # & Compute error in capturing correlation structure with greater sensitivity
    # & to extreme correlations
    # &
    # & Combines compute_correlation_error_3d with the mean squared error of
    # & the off-diagonal correlation entries whose true magnitude exceeds 0.7.
    # & KMeans clusters are matched one-to-one to GMM modes before the
    # & comparison, so each cluster is compared with the covariance of the
    # & mode it actually represents.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     gmm (GMMDistribution): Target GMM (uses n_components, means, covs)
    # &
    # & Returns:
    # &     float: (basic_error + 2 * strong_corr_mse) / 3 when any mode with
    # &         more than 10 particles has strong correlations; the basic error
    # &         otherwise or on failure; 1.0 with fewer than 5 particles per mode
    """
    n_modes = gmm.n_components

    # Skip if too few particles
    if len(particles) < n_modes * 5:
        return 1.0

    # Start with existing implementation
    basic_error = compute_correlation_error_3d(particles, gmm)

    try:
        from scipy.optimize import linear_sum_assignment

        # Use KMeans to cluster particles
        kmeans = KMeans(n_clusters=n_modes, random_state=42)
        cluster_labels = kmeans.fit_predict(particles)

        # Cluster indices are arbitrary: match them to modes by center distance
        true_means = np.asarray(gmm.means)
        cost_matrix = np.sum(
            (kmeans.cluster_centers_[:, None, :] - true_means[None, :, :]) ** 2, axis=-1
        )
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        cluster_to_mode = np.empty(n_modes, dtype=int)
        cluster_to_mode[row_ind] = col_ind
        mode_labels = cluster_to_mode[cluster_labels]

        # Compute correlation capture error with focus on extreme correlations
        corr_errors = []
        for i in range(n_modes):
            mode_particles = particles[mode_labels == i]

            # Need enough samples for reliable correlation
            if len(mode_particles) <= 10:
                continue

            empirical_corr = np.corrcoef(mode_particles, rowvar=False)

            # Convert the true covariance to a correlation matrix
            true_cov = np.asarray(gmm.covs[i], dtype=float)
            d = np.sqrt(np.diag(true_cov))
            true_corr = true_cov / np.outer(d, d)

            # Strong correlations only; the diagonal is always 1 on both sides
            # and would otherwise dilute the error with guaranteed zeros
            off_diagonal = ~np.eye(true_corr.shape[0], dtype=bool)
            mask = (np.abs(true_corr) > 0.7) & off_diagonal

            if np.any(mask):
                corr_errors.append(np.mean((empirical_corr[mask] - true_corr[mask])**2))

        if corr_errors:
            # Combine with basic error, emphasizing extreme correlation errors
            return (basic_error + 2 * np.mean(corr_errors)) / 3
        return basic_error

    except Exception as e:
        print(f"Error in enhanced correlation metric: {e}")
        return basic_error


def compute_sliced_wasserstein_distance_3d(particles, target_samples, n_projections=20):
    """
    # & Compute the Sliced Wasserstein Distance between two sample sets
    # &
    # & Both sets are projected onto random unit directions; for each direction
    # & the 1-Wasserstein distance is the mean absolute difference of the
    # & sorted projections (the shorter set is linearly interpolated to the
    # & length of the longer one). Works for any dimension d, not only 3.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     target_samples (np.ndarray): Target distribution samples, shape (m, d)
    # &     n_projections (int): Number of random projections
    # &
    # & Returns:
    # &     float: Sliced Wasserstein Distance; inf if the computation fails
    """
    def stretch(sorted_values, size):
        # Linearly interpolate a sorted 1-D sample onto `size` evenly spaced ranks
        ranks = np.linspace(0, len(sorted_values) - 1, size)
        return np.interp(ranks, np.arange(len(sorted_values)), sorted_values)

    try:
        particles = np.asarray(particles)
        target_samples = np.asarray(target_samples)

        # Random unit directions in the ambient dimension of the particles
        dim = particles.shape[1]
        directions = np.random.randn(n_projections, dim)
        directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)

        swd = 0.0
        for direction in directions:
            # Sorted projections are the empirical quantile functions
            particles_proj = np.sort(particles @ direction)
            target_proj = np.sort(target_samples @ direction)

            # Match lengths by stretching the shorter sample
            if len(particles_proj) > len(target_proj):
                target_proj = stretch(target_proj, len(particles_proj))
            elif len(particles_proj) < len(target_proj):
                particles_proj = stretch(particles_proj, len(target_proj))

            swd += np.mean(np.abs(particles_proj - target_proj))

        return swd / n_projections
    except Exception as e:
        print(f"Error computing sliced Wasserstein distance: {e}")
        return float('inf')


def estimate_kl_divergence_3d(particles, gmm, direction='forward', n_bins=20):
    """
    # & Estimate KL divergence between particles and GMM from 1-D marginals
    # &
    # & For each coordinate, both sample sets are histogrammed on a shared grid
    # & and the discrete KL divergence is computed; the result is the average
    # & over coordinates. Because only marginals are compared, dependence
    # & between coordinates is not measured. Works for any dimension d.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     gmm (GMMDistribution): Target GMM; gmm.sample(n) supplies the
    # &         reference samples
    # &     direction (str): 'reverse' gives KL(gmm samples || particles); any
    # &         other value ('forward') gives KL(particles || gmm samples)
    # &     n_bins (int): Number of bins for histogram estimation
    # &
    # & Returns:
    # &     float: Estimated KL divergence; inf if the computation fails
    """
    try:
        reference = gmm.sample(len(particles))
        if direction == 'reverse':
            p_samples, q_samples = reference, particles
        else:  # forward
            p_samples, q_samples = particles, reference

        n_dims = np.asarray(particles).shape[1]
        # Added to every bin to avoid division by zero or log(0)
        epsilon = 1e-10

        kl_sum = 0.0
        for dim in range(n_dims):
            # Extract 1D marginal distributions
            p_marginal = p_samples[:, dim]
            q_marginal = q_samples[:, dim]

            # Shared bin grid spanning both marginals
            all_samples = np.concatenate([p_marginal, q_marginal])
            low, high = np.min(all_samples), np.max(all_samples)
            if high == low:
                # Both marginals are the same point mass: zero-width bins would
                # give NaN densities, and the divergence is 0 by definition
                continue
            bin_edges = np.linspace(low, high, n_bins + 1)

            p_hist, _ = np.histogram(p_marginal, bins=bin_edges, density=True)
            q_hist, _ = np.histogram(q_marginal, bins=bin_edges, density=True)

            # Smooth, then renormalize to discrete probability vectors
            p_hist = p_hist + epsilon
            q_hist = q_hist + epsilon
            p_hist = p_hist / np.sum(p_hist)
            q_hist = q_hist / np.sum(q_hist)

            kl_sum += np.sum(p_hist * np.log(p_hist / q_hist))

        return kl_sum / n_dims  # Average across dimensions
    except Exception as e:
        print(f"Error estimating KL divergence: {e}")
        return float('inf')


def evaluate_method_3d(method_name, particles, gmm, target_samples, runtime=None):
    """
    # & Evaluate method performance using multiple metrics
    # &
    # & Args:
    # &     method_name (str): Name of the method
    # &     particles (np.ndarray): Particles from the method
    # &     gmm (GMMDistribution): Target GMM distribution
    # &     target_samples (np.ndarray): Target distribution samples
    # &     runtime (float, optional): Method runtime in seconds
    # &
    # & Returns:
    # &     dict: Evaluation metrics keyed by metric name, plus 'Method' and,
    # &         when runtime is given, 'Runtime (s)'
    """
    results = {
        'Method': method_name,
        'MMD': compute_mmd_3d(particles, target_samples),
        # direction='reverse' puts the GMM samples first: KL(target || particles).
        # Both KL entries must be computed from the method's particles; passing
        # target_samples here would compare the target with itself.
        'KL(Target||Method)': estimate_kl_divergence_3d(particles, gmm, direction='reverse'),
        # direction='forward' puts the particles first: KL(particles || target)
        'KL(Method||Target)': estimate_kl_divergence_3d(particles, gmm, direction='forward'),
        'Mode Coverage': compute_mode_coverage_3d(particles, gmm),
        'Correlation Error': compute_correlation_error_3d(particles, gmm),
        'ESS': compute_ess_3d(particles, gmm.score),
        'Sliced Wasserstein': compute_sliced_wasserstein_distance_3d(particles, target_samples),
    }

    if runtime is not None:
        results['Runtime (s)'] = runtime

    return results


# ========================================
# 3D Correlated GMM Distribution
# ========================================

class Correlated3DGMMDistribution(GMMDistribution):
    """
    # & 3D GMM distribution with challenging correlation structure for testing
    # &
    # & Four components, each with a different pair (or all three) of
    # & coordinates correlated.
    """
    def __init__(self, name=None, seed=None):
        """
        # & Initialize the correlated 3D GMM distribution
        # &
        # & Args:
        # &     name (str): Name for the distribution; defaults to "Correlated 3D GMM"
        # &     seed (int): Random seed forwarded to GMMDistribution
        """
        # Define means, covariances, and weights for a 4-mode correlated GMM
        means = np.array([
            [-2.0, -2.0, -2.0],    # Mode 1 (bottom-left-front) with XY correlation
            [2.0, -2.0, 2.0],      # Mode 2 (bottom-right-back) with XZ correlation
            [-2.0, 2.0, 2.0],      # Mode 3 (top-left-back) with YZ correlation
            [2.0, 2.0, -2.0]       # Mode 4 (top-right-front) with XYZ correlation
        ])

        # Create different correlation patterns for each mode
        # FIXED: Modified correlation values to ensure positive definiteness
        covs = np.array([
            # Mode 1: XY correlation (0.6), Z independent
            [[1.0, 0.6, 0.0],
             [0.6, 1.0, 0.0],
             [0.0, 0.0, 0.5]],

            # Mode 2: XZ correlation (0.6), Y independent
            [[1.0, 0.0, 0.6],
             [0.0, 0.5, 0.0],
             [0.6, 0.0, 1.0]],

            # Mode 3: YZ correlation (0.6), X independent
            [[0.5, 0.0, 0.0],
             [0.0, 1.0, 0.6],
             [0.0, 0.6, 1.0]],

            # Mode 4: XYZ negative correlation - reduced to -0.4 from -0.5
            [[1.2, -0.4, -0.4],
             [-0.4, 1.2, -0.4],
             [-0.4, -0.4, 1.2]]
        ])

        # Slightly uneven weights
        weights = np.array([0.3, 0.2, 0.2, 0.3])

        # Initialize base class
        super().__init__(means, covs, weights, name=name or "Correlated 3D GMM", seed=seed)

    @staticmethod
    def is_positive_definite(cov):
        """
        # & Check whether a matrix is positive definite
        # &
        # & Declared static so it works both as
        # & Correlated3DGMMDistribution.is_positive_definite(cov) and on an
        # & instance; without the decorator an instance call would pass the
        # & instance itself as cov.
        # &
        # & Args:
        # &     cov (np.ndarray): Square matrix to test
        # &
        # & Returns:
        # &     bool: True if a Cholesky factorization succeeds
        """
        try:
            np.linalg.cholesky(cov)
            return True
        except np.linalg.LinAlgError:
            return False


# ========================================
# 3D-compatible DVRL Adapter
# ========================================

class DVRLAdapter3D:
    """
    # & Adapter for DVRL to match interface with other methods for 3D data
    # &
    # & The wrapped model is probed for a `sample`, `generate` or `forward`
    # & method (in that order); whatever it yields is converted to a NumPy
    # & array with the dimensionality and particle count the caller expects.
    # &
    # & Args:
    # &     dvrl_model: Model object to draw samples from
    # &     n_samples (int): Number of particles to return
    """
    def __init__(self, dvrl_model, n_samples=1000):
        self.dvrl_model = dvrl_model
        self.n_samples = n_samples

    @staticmethod
    def _to_numpy(samples):
        """
        # & Convert a model output to np.ndarray; None for unsupported types
        """
        if isinstance(samples, torch.Tensor):
            return samples.detach().cpu().numpy()
        # Arrays are accepted as-is instead of being silently discarded
        if isinstance(samples, np.ndarray):
            return samples
        return None

    def _draw_from_model(self):
        """
        # & Try the first available model entry point; None if nothing usable
        # &
        # & Only the first callable among sample/generate/forward is tried; a
        # & failure there is reported and does not fall through to the others.
        """
        model = self.dvrl_model

        if callable(getattr(model, 'sample', None)):
            try:
                return self._to_numpy(model.sample(n_samples=self.n_samples))
            except Exception as e:
                print(f"Error using DVRL sample method: {e}")
            return None

        if callable(getattr(model, 'generate', None)):
            try:
                return self._to_numpy(model.generate(n_samples=self.n_samples))
            except Exception as e:
                print(f"Error using DVRL generate method: {e}")
            return None

        if callable(getattr(model, 'forward', None)):
            try:
                # Feed a single all-zero observation on the model's device
                dummy_obs = torch.zeros((1, model.obs_dim)).to(
                    next(model.parameters()).device)
                output = model.forward(dummy_obs)

                # Take the first element when the model returns a tuple
                if isinstance(output, tuple) and len(output) > 0:
                    samples = output[0]
                else:
                    samples = output

                # A batch of one is tiled up to the requested particle count
                if isinstance(samples, torch.Tensor) and samples.shape[0] == 1:
                    samples = samples.repeat(self.n_samples, 1)
                return self._to_numpy(samples)
            except Exception as e:
                print(f"Error using DVRL forward method: {e}")
            return None

        return None

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Adapts DVRL to the common interface
        # &
        # & Args:
        # &     initial_particles (np.ndarray): Initial particles, shape (n, d);
        # &         used as the result when the model yields nothing usable
        # &     score_fn (callable): Unused, kept for interface parity
        # &     target_samples (np.ndarray, optional): Unused, kept for interface parity
        # &     return_convergence (bool): Whether to return convergence info
        # &
        # & Returns:
        # &     np.ndarray: Particles of shape (n_samples, d), or noisy initial
        # &         particles if post-processing fails
        # &     dict (optional): Placeholder convergence information if
        # &         return_convergence=True
        """
        dim = initial_particles.shape[1]  # Dimensionality the caller expects
        particles = initial_particles.copy()

        try:
            drawn = self._draw_from_model()
            if drawn is not None:
                particles = drawn

            # Wrong dimensionality: replace with Gaussian noise matched to the
            # per-coordinate mean and std of the initial particles
            if particles.shape[1] != dim:
                print(f"Warning: DVRL output dimensions ({particles.shape[1]}) don't match input dimensions ({dim})")
                particles = np.random.randn(self.n_samples, dim) * np.std(initial_particles, axis=0)
                particles += np.mean(initial_particles, axis=0)

            # Wrong count: subsample, or resample with replacement when too few
            if len(particles) != self.n_samples:
                indices = np.random.choice(len(particles), self.n_samples, replace=len(particles) < self.n_samples)
                particles = particles[indices]

        except Exception as e:
            print(f"Failed to generate particles from DVRL, using fallback: {e}")
            # Use initial particles but add some noise to provide variety
            particles = initial_particles.copy()
            particles += np.random.randn(*particles.shape) * 0.1

        if return_convergence:
            convergence_info = {
                'iterations': 1,
                'delta_norm_history': [0.0],  # Placeholder history
                'step_size_history': [0.0]    # Placeholder history
            }
            return particles, convergence_info
        return particles


# ========================================
# Stable SVGD for 3D implementation
# ========================================

class StableSVGD(SVGD):
    """
    # & Enhanced SVGD with aggressive mode exploration for 3D multi-modal distributions
    """
    def __init__(
        self,
        kernel=None,
        step_size=0.01,
        n_iter=300,
        tol=1e-5,
        bandwidth_scale=0.5,
        add_noise=True,
        noise_level=0.25,
        noise_decay=0.92,
        resample_freq=10,
        adaptive_step=True,
        mode_detection=True,
        lambda_corr=0.4,
        verbose=True,
        target_info=None,
    ):
        """
        # & Configure the multi-modal SVGD variant and its bookkeeping state

        Args:
            kernel: Kernel object handed to the parent constructor. When None,
                ``RBFKernel(bandwidth=1.0, adaptive=True)`` is created.
            step_size, n_iter, tol, verbose: Forwarded unchanged to the parent
                constructor.
            bandwidth_scale, resample_freq, lambda_corr: Stored on the
                instance; the methods defined in this class do not read them.
            add_noise: Whether ``update`` injects exploration noise.
            noise_level: Starting scale of that exploration noise.
            noise_decay: Multiplicative factor ``update`` applies to the noise
                scale each iteration.
            adaptive_step: Whether ``update`` applies its step-size decay
                schedule.
            mode_detection: Whether ``_compute_svgd_update`` refreshes mode
                assignments and re-weights forces by mode occupancy.
            target_info: Optional mapping; ``initialize_particles`` reads its
                ``'centers'`` and ``'covs'`` entries.
        """
        # Fall back to an adaptive-bandwidth RBF kernel when the caller gave none
        if kernel is None:
            kernel = RBFKernel(bandwidth=1.0, adaptive=True)

        super().__init__(
            kernel=kernel,
            step_size=step_size,
            n_iter=n_iter,
            tol=tol,
            verbose=verbose,
        )

        # Optional description of the target (mode centers / covariances)
        self.target_info = target_info

        # Caller-supplied tuning knobs
        self.lambda_corr = lambda_corr
        self.mode_detection = mode_detection
        self.adaptive_step = adaptive_step
        self.resample_freq = resample_freq
        self.noise_decay = noise_decay
        self.noise_level = noise_level
        self.add_noise = add_noise
        self.bandwidth_scale = bandwidth_scale

        # Mode bookkeeping, filled in once particles have been initialized
        self.detected_modes = self.mode_assignments = None
        self._mode_centers = self._mode_covs = None
        self._cholesky_cache = {}

        # Fixed settings for correlation handling
        self.use_mahalanobis = True   # Mahalanobis distance for mode assignment
        self.correlation_scale = 1.1  # exponent used when rescaling scores

        # Fixed settings for mode balancing
        self.missing_mode_threshold = 0.01
        self.direct_intervention_freq = 10
        self.mode_balance_freq = 8

        # Number of iterations the last ``update`` call executed
        self.iterations_run = 0
        
    def initialize_particles(self, particles):
        """
        # & Initialize particles with more uniform mode coverage

        When ``self.target_info`` supplies a non-empty ``'centers'`` entry the
        incoming positions are discarded: particles are split as evenly as
        possible across the modes and drawn around each center, using the
        matching ``'covs'`` entry (when present) to shape the spread. The
        centers/covariances are remembered on the instance, successful
        Cholesky factors are cached in ``self._cholesky_cache``, and
        ``self.mode_assignments`` records the mode each particle was drawn
        for. Finally 10% of the particles receive extra isotropic noise.

        Args:
            particles: Array of shape (n_particles, dim).

        Returns:
            A new (n_particles, dim) array when target centers are available;
            otherwise ``particles`` itself, unchanged.
        """
        n_particles, dim = particles.shape

        info = self.target_info
        if info is None or 'centers' not in info:
            return particles

        centers = info['centers']
        n_modes = len(centers)
        if n_modes == 0:
            # An empty list of centers carries no usable mode information;
            # splitting particles across zero modes would divide by zero.
            return particles

        covs = info.get('covs', None)

        # Remember the mode description for the balancing steps
        self._mode_centers = centers
        self._mode_covs = covs

        # Even split; the first `extra` modes absorb the remainder
        base, extra = divmod(n_particles, n_modes)
        particles_per_mode = [base + 1 if i < extra else base
                              for i in range(n_modes)]

        new_particles = np.zeros((n_particles, dim))
        self.mode_assignments = np.zeros(n_particles, dtype=int)

        start = 0
        for i, n_mode in enumerate(particles_per_mode):
            stop = start + n_mode

            offsets = None
            if covs is not None and i < len(covs):
                try:
                    # Small ridge keeps the factorization numerically stable
                    cov_reg = covs[i] + 1e-5 * np.eye(dim)
                    try:
                        factor = np.linalg.cholesky(cov_reg)
                        self._cholesky_cache[i] = factor
                    except np.linalg.LinAlgError:
                        # Not positive definite: use a clipped eigen factor
                        eigvals, eigvecs = np.linalg.eigh(cov_reg)
                        eigvals = np.maximum(eigvals, 1e-6)
                        factor = eigvecs @ np.diag(np.sqrt(eigvals))
                    # Rows of z @ factor.T have covariance factor @ factor.T;
                    # 0.6 tightens the spread around the center
                    offsets = np.random.randn(n_mode, dim) @ factor.T * 0.6
                except Exception:
                    # Best effort: an unusable covariance (bad shape, failed
                    # decomposition) degrades to isotropic noise below
                    offsets = None

            if offsets is None:
                offsets = np.random.randn(n_mode, dim) * 0.6

            new_particles[start:stop] = centers[i] + offsets
            self.mode_assignments[start:stop] = i
            start = stop

        # Extra exploration noise on a random 10% of the particles
        explore_fraction = 0.1
        explore_count = int(n_particles * explore_fraction)
        if explore_count > 0:
            explore_indices = np.random.choice(n_particles, explore_count, replace=False)
            new_particles[explore_indices] += np.random.randn(explore_count, dim) * 1.5

        return new_particles
    
    def _update_mode_assignments(self, particles):
        """
        # & Update mode assignments using Mahalanobis distance when possible

        Each particle is assigned to the mode whose center is nearest. For a
        mode with a usable covariance (and ``self.use_mahalanobis`` set) the
        Mahalanobis distance under the ridge-regularized covariance is used;
        otherwise the Euclidean distance is used for that mode.

        Args:
            particles: Array of shape (n_particles, dim).

        Returns:
            ``(self._mode_centers, self.mode_assignments)`` after the update,
            or ``(None, None)`` when no mode centers are known.
        """
        if self._mode_centers is None or len(self._mode_centers) == 0:
            return None, None

        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]

        # Covariances are only consulted when Mahalanobis assignment is on
        covs = self._mode_covs if self.use_mahalanobis else None

        distances = np.zeros((n_particles, n_modes))
        for i, center in enumerate(self._mode_centers):
            diff = particles - center

            sq_dist = None
            if covs is not None and i < len(covs):
                try:
                    # Ridge term keeps the inverse well conditioned
                    inv_cov = np.linalg.inv(covs[i] + 1e-5 * np.eye(dim))
                    # Row-wise quadratic form diff_j^T inv_cov diff_j
                    sq_dist = np.sum((diff @ inv_cov) * diff, axis=1)
                except Exception:
                    # Singular or malformed covariance: use Euclidean instead
                    sq_dist = None

            if sq_dist is None:
                sq_dist = np.sum(diff ** 2, axis=1)

            distances[:, i] = np.sqrt(sq_dist)

        # Nearest center wins
        self.mode_assignments = np.argmin(distances, axis=1)

        return self._mode_centers, self.mode_assignments
    
    def _direct_mode_intervention(self, particles, iteration):
        """
        # & More aggressive mode intervention for better mode coverage

        Re-assigns particles to modes, then teleports particles out of the
        most populated mode into under-populated ones in two passes:

        1. Modes holding fewer than ``missing_mode_threshold`` times the even
           share receive up to 20% of an even share (never more than 5% of
           all particles), and the donor keeps at least half an even share.
        2. During the first half of the run only, modes below 40% of the even
           share receive up to 10% of an even share, provided the donor holds
           at least 120% of it and stays at or above the even share.

        Args:
            particles: Array of shape (n_particles, dim); modified in place.
            iteration: Current iteration index; only gates the second pass.

        Returns:
            The same ``particles`` array.
        """
        if self._mode_centers is None or len(self._mode_centers) == 0:
            return particles

        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        target_per_mode = n_particles / n_modes

        # Refresh assignments so the counts reflect current positions
        self._update_mode_assignments(particles)
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)

        # Pass 1: modes that are (almost) empty
        critically_low = np.where(
            mode_counts < target_per_mode * self.missing_mode_threshold)[0]

        if len(critically_low) > 0:
            particles_to_move = min(int(target_per_mode * 0.2),
                                    int(n_particles * 0.05))

            for mode_idx in critically_low:
                donor = np.argmax(mode_counts)
                if donor == mode_idx:
                    # Pick the next most populated mode instead
                    others = mode_counts.copy()
                    others[donor] = 0
                    if np.sum(others) == 0:
                        continue  # No other modes have particles
                    donor = np.argmax(others)

                donor_indices = np.where(self.mode_assignments == donor)[0]
                n_move = min(particles_to_move, len(donor_indices),
                             int(mode_counts[donor] - target_per_mode * 0.5))

                if n_move > 0:
                    self._relocate_to_mode(particles, donor_indices[:n_move],
                                           mode_idx, allow_new_cholesky=True)
                    mode_counts[donor] -= n_move
                    mode_counts[mode_idx] += n_move

        # Pass 2: modes that are populated but clearly below their share
        moderately_low = np.where(
            (mode_counts >= target_per_mode * self.missing_mode_threshold)
            & (mode_counts < target_per_mode * 0.4))[0]

        if len(moderately_low) > 0 and iteration < self.n_iter * 0.5:
            particles_to_move = int(target_per_mode * 0.1)

            for mode_idx in moderately_low:
                if mode_counts[mode_idx] >= target_per_mode * 0.4:
                    continue

                donor = np.argmax(mode_counts)
                if donor == mode_idx or mode_counts[donor] < target_per_mode * 1.2:
                    continue  # Skip if not enough excess particles

                donor_indices = np.where(self.mode_assignments == donor)[0]
                n_move = min(particles_to_move, len(donor_indices),
                             int(mode_counts[donor] - target_per_mode))

                if n_move > 0:
                    # This pass reuses a cached Cholesky factor but never
                    # computes a new one (it goes straight to the eigen factor)
                    self._relocate_to_mode(particles, donor_indices[:n_move],
                                           mode_idx, allow_new_cholesky=False)
                    mode_counts[donor] -= n_move
                    mode_counts[mode_idx] += n_move

        return particles

    def _relocate_to_mode(self, particles, indices, mode_idx, allow_new_cholesky=True):
        """
        # & Overwrite particles[indices] with draws around one mode center

        The draws use the mode's covariance when one is available: a cached
        Cholesky factor first, then (if ``allow_new_cholesky``) a freshly
        computed and cached one, then a clipped eigen factor. If none of
        these can be produced the noise stays isotropic. The offsets are
        scaled by 0.5 and ``self.mode_assignments`` is updated for the moved
        particles.
        """
        dim = particles.shape[1]
        noise = np.random.randn(len(indices), dim)

        offsets = None
        if self._mode_covs is not None and mode_idx < len(self._mode_covs):
            try:
                factor = self._cholesky_cache.get(mode_idx)
                if factor is None:
                    cov_reg = self._mode_covs[mode_idx] + 1e-5 * np.eye(dim)
                    if allow_new_cholesky:
                        try:
                            factor = np.linalg.cholesky(cov_reg)
                            self._cholesky_cache[mode_idx] = factor
                        except np.linalg.LinAlgError:
                            factor = None
                    if factor is None:
                        eigvals, eigvecs = np.linalg.eigh(cov_reg)
                        eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive
                        factor = eigvecs @ np.diag(np.sqrt(eigvals))
                offsets = noise @ factor.T
            except Exception:
                # Best effort: unusable covariance falls back to isotropic
                offsets = None

        if offsets is None:
            offsets = noise

        particles[indices] = self._mode_centers[mode_idx] + offsets * 0.5
        self.mode_assignments[indices] = mode_idx
    
    def _compute_svgd_update(self, particles, score_fn, iteration=0):
        """
        # & Compute SVGD update with improved correlation handling

        Steps: refresh mode assignments every fifth iteration, evaluate the
        score, rescale each particle's score along the eigen-directions of
        its mode covariance, combine kernel-weighted attraction with scaled
        repulsion, re-weight both by mode occupancy, and clip each row of the
        result to 15% of the average inter-particle distance.

        Args:
            particles: Array of shape (n_particles, dim).
            score_fn: Callable mapping ``particles`` to an array of the same
                shape. If it raises, zero scores are used.
            iteration: Current iteration index; selects the repulsion factor
                and the assignment refresh.

        Returns:
            Array of shape (n_particles, dim) holding the update direction.

        NOTE(review): assumes ``self.kernel.evaluate`` returns an (n, n)
        array and ``self.kernel.gradient`` an (n, n, dim) array - confirm
        against the kernel implementation.
        """
        n_particles, dim = particles.shape

        # Update mode assignments more frequently
        if self.mode_detection and iteration % 5 == 0:
            self._update_mode_assignments(particles)

        try:
            # Copy: the scores are rescaled in place below, and the array
            # returned by score_fn may be owned (or cached) by the caller
            score_values = np.array(score_fn(particles), dtype=float)
            if not np.all(np.isfinite(score_values)):
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
        except Exception:
            # If score function fails, use zero scores
            score_values = np.zeros_like(particles)

        # Correlation guidance: stretch scores along each mode's principal axes
        if self._mode_covs is not None and self.mode_assignments is not None:
            for mode_idx in np.unique(self.mode_assignments):
                if mode_idx >= len(self._mode_covs):
                    continue
                try:
                    eigvals, eigvecs = np.linalg.eigh(self._mode_covs[mode_idx])
                except Exception:
                    # Undecomposable covariance: leave these scores unchanged
                    continue
                eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive
                scaling = np.sqrt(eigvals) ** self.correlation_scale

                members = self.mode_assignments == mode_idx
                # Project onto eigenvectors, scale, project back (row-wise)
                projected = score_values[members] @ eigvecs
                score_values[members] = (projected * scaling) @ eigvecs.T

        # Kernel matrix and its gradient, with non-finite entries zeroed
        K = self.kernel.evaluate(particles)
        grad_K = self.kernel.gradient(particles)
        if not np.all(np.isfinite(K)):
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)
        if not np.all(np.isfinite(grad_K)):
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)

        # attractive[i] = sum_j K[i, j] * score[j]
        attractive = K @ score_values
        # repulsive[i] = sum_j grad_K[j, i, :]
        repulsive = np.sum(grad_K, axis=0)

        # Dynamic repulsion factor - stronger early on
        if iteration < self.n_iter * 0.2:
            repulsion_factor = 2.2 * (1.0 - iteration / (self.n_iter * 0.2) * 0.3)
        elif iteration < self.n_iter * 0.5:
            repulsion_factor = 1.8
        else:
            repulsion_factor = 1.5
        repulsive *= repulsion_factor

        # Mode balancing: boost forces in sparsely populated modes.
        # NOTE(review): the even share is computed over *occupied* modes,
        # not over all known mode centers - confirm this is intended.
        if self.mode_detection and self.mode_assignments is not None:
            n_occupied = len(np.unique(self.mode_assignments))
            mode_counts = np.bincount(self.mode_assignments, minlength=n_occupied)

            if len(mode_counts) > 0 and np.any(mode_counts > 0):
                target_count = n_particles / n_occupied
                mode_weights = (target_count / np.maximum(mode_counts, 1)) ** 1.8
                # Cap weights to avoid numerical issues
                mode_weights = np.clip(mode_weights, 0.5, 8.0)

                for mode_idx, weight in enumerate(mode_weights):
                    mode_mask = self.mode_assignments == mode_idx
                    attractive[mode_mask] *= weight

                    # Never amplify repulsion in very sparse modes, so their
                    # few particles are not scattered
                    if mode_counts[mode_idx] < target_count * 0.2:
                        repulsive[mode_mask] *= min(1.0, weight * 0.5)
                    else:
                        repulsive[mode_mask] *= weight

        update = attractive + repulsive
        if not np.all(np.isfinite(update)):
            update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)

        # Clip radius: 15% of the mean pairwise distance of a random sample
        # of up to 50 particles, with 1.0 as the fallback
        max_norm = 1.0
        try:
            sample_size = min(50, n_particles)
            avg_dist = 1.0
            if sample_size > 1:
                indices = np.random.choice(n_particles, sample_size, replace=False)
                sample = particles[indices]
                gaps = np.linalg.norm(sample[:, np.newaxis, :] - sample[np.newaxis, :, :], axis=-1)
                avg_dist = gaps[np.triu_indices(sample_size, k=1)].mean()
            # A zero or non-finite distance (collapsed or corrupted particle
            # set) would freeze or disable clipping; keep the fallback radius
            if np.isfinite(avg_dist) and avg_dist > 0:
                max_norm = avg_dist * 0.15
        except Exception:
            max_norm = 1.0

        update_norms = np.sqrt(np.sum(update ** 2, axis=1))
        large_updates = update_norms > max_norm
        if np.any(large_updates):
            scale_factors = max_norm / update_norms[large_updates]
            update[large_updates] *= scale_factors[:, np.newaxis]

        return update
    
    def _mode_balanced_resample(self, particles):
        """
        # & Enhanced mode-balancing resampling with correlation preservation

        For every mode holding fewer than 60% of an even share, particles of
        the most over-populated mode (above 120% of an even share) are
        replaced by fresh draws around the under-populated mode's center.
        The existing ``self.mode_assignments`` are used as-is (not refreshed)
        and are updated for the replaced particles.

        Args:
            particles: Array of shape (n_particles, dim); not modified.

        Returns:
            A modified copy of ``particles``, or ``particles`` itself when no
            mode information is available.
        """
        if self._mode_centers is None or self.mode_assignments is None:
            return particles

        n_modes = len(self._mode_centers)
        if n_modes == 0:
            # No modes to balance; avoids dividing by zero below
            return particles

        n_particles = len(particles)
        dim = particles.shape[1]
        new_particles = particles.copy()

        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        target_count = n_particles / n_modes

        for mode_idx in range(n_modes):
            if mode_counts[mode_idx] >= target_count * 0.6:
                continue

            donors = np.where(mode_counts > target_count * 1.2)[0]
            if len(donors) == 0:
                continue

            # Donor is the most over-populated mode.
            # NOTE(review): the amount taken is capped by the receiver's
            # deficit only, so a donor can drop below the even share.
            replace_mode = donors[np.argmax(mode_counts[donors])]
            candidates = np.where(self.mode_assignments == replace_mode)[0]

            mode_deficit = int(target_count - mode_counts[mode_idx])
            n_replace = min(mode_deficit, len(candidates))
            replace_indices = candidates[:n_replace]

            noise = np.random.randn(n_replace, dim)
            offsets = None
            if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                try:
                    factor = self._cholesky_cache.get(mode_idx)
                    if factor is None:
                        cov_reg = self._mode_covs[mode_idx] + 1e-5 * np.eye(dim)
                        try:
                            factor = np.linalg.cholesky(cov_reg)
                            self._cholesky_cache[mode_idx] = factor
                        except np.linalg.LinAlgError:
                            # Not positive definite: clipped eigen factor
                            eigvals, eigvecs = np.linalg.eigh(cov_reg)
                            eigvals = np.maximum(eigvals, 1e-6)
                            factor = eigvecs @ np.diag(np.sqrt(eigvals))
                    # Rows of noise @ factor.T carry the mode's correlation
                    offsets = noise @ factor.T
                except Exception:
                    # Best effort: unusable covariance -> isotropic noise
                    offsets = None

            if offsets is None:
                offsets = noise

            new_particles[replace_indices] = self._mode_centers[mode_idx] + offsets * 0.5
            self.mode_assignments[replace_indices] = mode_idx

            mode_counts[replace_mode] -= n_replace
            mode_counts[mode_idx] += n_replace

        return new_particles
    
    def update(self, particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Enhanced SVGD optimization with aggressive mode exploration

        Runs up to ``self.n_iter`` iterations. Each iteration optionally
        performs direct mode intervention (first half of the run), computes
        the SVGD update, adds scheduled exploration noise, takes a step,
        optionally rebalances modes, and checks for instability/convergence.

        Args:
            particles: Array of shape (n_particles, dim); it is copied, the
                caller's array is not modified.
            score_fn: Callable passed to ``_compute_svgd_update``.
            target_samples: Accepted for interface compatibility; not used.
            return_convergence: When True, also return convergence info.

        Returns:
            The final particle array, or ``(particles, info)`` where ``info``
            has keys ``'delta_norm_history'``, ``'step_size_history'`` and
            ``'iterations_run'``.
        """
        particles = particles.copy()
        n_particles, dim = particles.shape

        particles = self.initialize_particles(particles)

        # `base_step_size` is the numerator of the adaptive schedule; it is
        # shrunk after an unstable step so the reduction survives the next
        # schedule update instead of being overwritten by it
        base_step_size = self.step_size
        curr_step_size = base_step_size
        current_noise = self.noise_level if self.add_noise else 0.0
        delta_norm_history = []
        step_size_history = []
        # Defined up front so the final report works even when n_iter == 0
        delta_norm = float('nan')

        # Guard the modulo operands against zero
        balance_freq = max(1, self.mode_balance_freq)
        half_balance_freq = max(1, balance_freq // 2)
        intervention_freq = max(1, self.direct_intervention_freq)

        iterator = range(self.n_iter)
        if self.verbose:
            iterator = tqdm(iterator, desc="Stable SVGD 3D")

        # Force early mode intervention to ensure coverage
        if self._mode_centers is not None:
            particles = self._direct_mode_intervention(particles, 0)

        for t in iterator:
            # Direct mode intervention during the first half of the run
            if t < self.n_iter * 0.5 and t % intervention_freq == 0:
                particles = self._direct_mode_intervention(particles, t)

            update = self._compute_svgd_update(particles, score_fn, t)

            # Exploration noise: strong and every iteration early on, then
            # weaker and sparser, and none after 70% of the run
            if current_noise > 0:
                if t < self.n_iter * 0.2:
                    noise_gain, noise_period = 1.5, 1
                elif t < self.n_iter * 0.4:
                    noise_gain, noise_period = 0.8, 2
                elif t < self.n_iter * 0.7:
                    noise_gain, noise_period = 0.4, 4
                else:
                    noise_gain, noise_period = 0.0, 0

                if noise_period and t % noise_period == 0:
                    noise = np.random.randn(*particles.shape) * (current_noise * noise_gain)
                    update = update + noise

                # Slower noise decay in early iterations
                if t < self.n_iter * 0.2:
                    current_noise *= self.noise_decay ** 0.5
                else:
                    current_noise *= self.noise_decay

            new_particles = particles + curr_step_size * update

            # Mode-balancing resampling: most frequent early, rarer later
            if t > 0:
                if t < self.n_iter * 0.3 and t % half_balance_freq == 0:
                    new_particles = self._mode_balanced_resample(new_particles)
                elif t < self.n_iter * 0.7 and t % balance_freq == 0:
                    new_particles = self._mode_balanced_resample(new_particles)
                elif t % (balance_freq * 2) == 0:
                    new_particles = self._mode_balanced_resample(new_particles)

            delta = new_particles - particles
            delta_norm = np.linalg.norm(delta) / (n_particles * dim)

            # Reject an unstable step and retry with a smaller step size
            if np.isnan(delta_norm) or np.isinf(delta_norm) or delta_norm > 1e10:
                base_step_size *= 0.1
                curr_step_size *= 0.1
                if self.verbose:
                    print(f"Unstable update detected! Reducing step size to {curr_step_size:.6f}")
                continue

            delta_norm_history.append(delta_norm)
            step_size_history.append(curr_step_size)

            particles = new_particles

            # Step size schedule: gentle decay early, faster decay later
            if self.adaptive_step:
                decay_rate = 0.005 if t < self.n_iter * 0.3 else 0.02
                curr_step_size = base_step_size / (1.0 + decay_rate * t)

            # Convergence is only accepted in the second half of the run
            if t > self.n_iter // 2 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")

        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run
            }
            return particles, convergence_info

        return particles


class StableSVGDAdapter:
    """
    # & Thin wrapper exposing StableSVGD through the shared adapter interface
    # &
    # & The wrapped optimizer is always built with the same exploration-heavy
    # & tuning; only the iteration budget, step size, verbosity and target
    # & information can be chosen by the caller.
    """
    def __init__(self, n_iter=300, step_size=0.01, verbose=True, target_info=None):
        # Settings chosen by the caller
        svgd_options = {
            'step_size': step_size,
            'n_iter': n_iter,
            'verbose': verbose,
            'target_info': target_info,
        }
        # Fixed tuning shared by every adapter instance
        svgd_options.update(
            add_noise=True,
            noise_level=0.25,      # Much higher noise for exploration
            noise_decay=0.92,      # Slower decay to maintain exploration
            resample_freq=10,      # More frequent resampling
            adaptive_step=True,
            mode_detection=True,
            lambda_corr=0.4,       # Higher correlation emphasis
        )
        self.svgd = StableSVGD(**svgd_options)

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """Delegate to the wrapped optimizer, forwarding all arguments positionally."""
        outcome = self.svgd.fit_transform(
            initial_particles, score_fn, target_samples, return_convergence)
        return outcome


# ========================================
# ESCORT 3D Implementation
# ========================================

class ESCORT3D(AdaptiveSVGD):
    """
    # & 3D implementation of ESCORT framework for POMDPs
    """
    def __init__(self, kernel=None, gswd=None, step_size=0.02,
                 n_iter=300, tol=1e-5, lambda_reg=0.3,
                 decay_step_size=True, verbose=True,
                 noise_level=0.2, noise_decay=0.95,
                 target_info=None):
        """Configure the optimizer and its mode-tracking state.

        Args:
            kernel: kernel object; an adaptive ``RBFKernel`` is created when None.
            gswd: GSWD regularizer; a correlation-aware instance with 40 random
                projections and 10 optimization steps is created when None.
            step_size, n_iter, tol, lambda_reg, decay_step_size, verbose:
                forwarded unchanged to the parent constructor.
            noise_level: initial scale of the exploration noise.
            noise_decay: multiplicative decay applied to the noise scale.
            target_info: optional dict; the rest of the class reads its
                ``'centers'`` and ``'covs'`` entries.
        """
        # Substitute default components for anything the caller left out
        kernel = RBFKernel(adaptive=True) if kernel is None else kernel
        if gswd is None:
            gswd = GSWD(
                n_projections=40,
                projection_method='random',
                optimization_steps=10,
                correlation_aware=True,
            )

        # Let the parent class store the generic optimizer settings
        super().__init__(
            kernel=kernel,
            gswd=gswd,
            step_size=step_size,
            n_iter=n_iter,
            tol=tol,
            lambda_reg=lambda_reg,
            decay_step_size=decay_step_size,
            verbose=verbose,
        )

        # Exploration noise schedule
        self.noise_level = noise_level
        self.noise_decay = noise_decay

        # Mode bookkeeping, filled in once particles are initialized
        self.target_info = target_info
        self.detected_modes = None
        self.mode_assignments = None

        # Cached per-mode data (centers, covariances, covariance factors)
        self._mode_centers = None
        self._mode_covs = None
        self._cholesky_cache = {}

        # Number of iterations executed by the most recent run
        self.iterations_run = 0

        # Heuristic knobs for mode coverage
        self.mode_balance_freq = 10         # iterations between rebalancing passes
        self.aggressive_exploration = True  # flag for more aggressive exploration
        self.repulsion_factor = 2.5         # stronger early repulsion
    
    def _initialize_particles(self, particles):
        """Optimized particle initialization for 3D with better mode coverage"""
        n_particles, dim = particles.shape
        
        # Only optimize if we have target info
        if self.target_info is not None and 'centers' in self.target_info:
            centers = self.target_info['centers']
            covs = self.target_info.get('covs', None)
            
            # Cache mode information
            self._mode_centers = centers
            self._mode_covs = covs
            
            # Allocate particles to modes
            n_modes = len(centers)
            new_particles = np.zeros_like(particles)
            
            # Better distribution across modes - slightly bias toward complex modes
            particles_per_mode = []
            total_allocated = 0
            
            for i in range(n_modes):
                # Allocate more particles to modes with stronger correlation
                if covs is not None and i < len(covs):
                    # Compute correlation strength using off-diagonal elements
                    cov = covs[i]
                    off_diag_sum = np.sum(np.abs(cov - np.diag(np.diag(cov))))
                    corr_factor = 1.0 + off_diag_sum / (dim * dim)
                    # Allocate 10-30% more to complex correlation modes
                    mode_count = int(n_particles / n_modes * min(1.3, max(1.1, corr_factor)))
                else:
                    mode_count = n_particles // n_modes
                
                particles_per_mode.append(mode_count)
                total_allocated += mode_count
            
            # Adjust to match exactly n_particles
            while total_allocated > n_particles:
                idx = np.argmax(particles_per_mode)
                particles_per_mode[idx] -= 1
                total_allocated -= 1
                
            while total_allocated < n_particles:
                idx = np.argmin(particles_per_mode)
                particles_per_mode[idx] += 1
                total_allocated += 1
            
            idx = 0
            for i in range(n_modes):
                n_mode = particles_per_mode[i]
                
                # Generate correlated samples efficiently
                try:
                    # Cache Cholesky decompositions
                    if i not in self._cholesky_cache and covs is not None and i < len(covs):
                        # Add small regularization for numerical stability
                        cov_reg = covs[i] + 1e-5 * np.eye(dim)
                        try:
                            L = np.linalg.cholesky(cov_reg)
                            self._cholesky_cache[i] = L
                        except:
                            # If Cholesky fails, use eigendecomposition
                            eigvals, eigvecs = np.linalg.eigh(cov_reg)
                            eigvals = np.maximum(eigvals, 1e-5)  # Ensure positive eigenvalues
                            L = eigvecs @ np.diag(np.sqrt(eigvals))
                            self._cholesky_cache[i] = L
                    
                    if i in self._cholesky_cache:
                        L = self._cholesky_cache[i]
                        # Use vectorized operations with appropriate scaling
                        z = np.random.randn(n_mode, dim)
                        # Scale for better initial clustering
                        correlated = np.dot(z, L.T) * 0.5
                        new_particles[idx:idx+n_mode] = centers[i] + correlated
                    else:
                        # If no cached decomposition, use isotropic
                        new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.5
                except:
                    # Simple fallback
                    new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.5
                
                idx += n_mode
            
            # Initialize mode assignments
            self.mode_assignments = np.zeros(n_particles, dtype=int)
            start_idx = 0
            for i, n_mode in enumerate(particles_per_mode):
                self.mode_assignments[start_idx:start_idx+n_mode] = i
                start_idx += n_mode
            
            self.detected_modes = centers
            
            return new_particles
        
        return particles
    
    def _update_mode_assignments(self, particles):
        """Efficiently update mode assignments for 3D with Mahalanobis distance"""
        n_particles = len(particles)
        
        # If we already have mode centers, assign to nearest considering correlation
        if self._mode_centers is not None:
            # Initialize or reset mode assignments if needed
            if self.mode_assignments is None or len(self.mode_assignments) != n_particles:
                self.mode_assignments = np.zeros(n_particles, dtype=int)
            
            # Use Mahalanobis distance when possible for better correlation-aware assignment
            if self._mode_covs is not None:
                # For each particle, compute Mahalanobis distance to each mode
                distances = np.zeros((n_particles, len(self._mode_centers)))
                
                for i, center in enumerate(self._mode_centers):
                    if i < len(self._mode_covs):
                        try:
                            # Try to compute inverse covariance
                            cov = self._mode_covs[i]
                            # Add small regularization for numerical stability
                            cov_reg = cov + 1e-5 * np.eye(cov.shape[0])
                            inv_cov = np.linalg.inv(cov_reg)
                            
                            # Compute Mahalanobis distance for all particles (vectorized)
                            diff = particles - center
                            # Batch matrix multiplication for all particles
                            for j in range(n_particles):
                                d = diff[j]
                                distances[j, i] = np.sqrt(d @ inv_cov @ d)
                        except:
                            # Fallback to Euclidean if inverse fails
                            diff = particles - center
                            distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
                    else:
                        # Fallback to Euclidean for modes without covariance
                        diff = particles - center
                        distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
            else:
                # Use Euclidean distance if no covariance info
                distances = np.zeros((n_particles, len(self._mode_centers)))
                for i, center in enumerate(self._mode_centers):
                    diff = particles - center
                    distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
            
            # Assign to closest center
            self.mode_assignments = np.argmin(distances, axis=1)
            
            return True
        
        return False
    
    def _compute_svgd_update(self, particles, score_fn, iteration=0):
        """Compute SVGD update with better correlation awareness and mode balancing for 3D"""
        n_particles, dim = particles.shape
        
        # Update mode assignments periodically - more frequently in early iterations
        update_freq = max(5, min(10, self.n_iter // 30))
        if iteration == 0 or (iteration % update_freq == 0 and self._mode_centers is not None):
            self._update_mode_assignments(particles)
        
        # Get score values with better error handling
        try:
            score_values = score_fn(particles)
            
            # Check for NaN/Inf and replace with zeros
            if np.any(np.isnan(score_values)) or np.any(np.isinf(score_values)):
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
        except:
            # If score function fails, use zero scores
            score_values = np.zeros_like(particles)
        
        # Apply correlation guidance and mode-specific updates
        if self._mode_covs is not None and self.mode_assignments is not None:
            # Compute eigendecompositions for each mode once
            mode_eigen_cache = {}
            
            # Pre-compute eigendecompositions for each mode covariance
            for mode_idx, cov in enumerate(self._mode_covs):
                if mode_idx not in mode_eigen_cache:
                    try:
                        # Compute eigendecomposition once per mode
                        eigvals, eigvecs = np.linalg.eigh(cov)
                        # Ensure positivity for numerical stability
                        eigvals = np.maximum(eigvals, 1e-6)
                        mode_eigen_cache[mode_idx] = (eigvals, eigvecs)
                    except:
                        # If decomposition fails, use identity
                        mode_eigen_cache[mode_idx] = (np.ones(dim), np.eye(dim))
            
            # Apply correlation-aware updates for each mode
            for mode_idx in np.unique(self.mode_assignments):
                if mode_idx >= len(self._mode_covs):
                    continue
                    
                # Get particles in this mode
                mode_mask = self.mode_assignments == mode_idx
                mode_indices = np.where(mode_mask)[0]
                
                if len(mode_indices) > 0:
                    # Get eigenvectors and eigenvalues for this mode
                    if mode_idx in mode_eigen_cache:
                        eigvals, eigvecs = mode_eigen_cache[mode_idx]
                        
                        # Apply correlation-aware gradient scaling
                        for idx in mode_indices:
                            # Project score onto eigenvectors
                            proj_score = np.dot(eigvecs.T, score_values[idx])
                            
                            # Scale by sqrt of eigenvalues - stronger effect
                            scale_factor = 1.2  # Amplify correlation effect
                            proj_score = proj_score * (np.sqrt(eigvals) ** scale_factor)
                            
                            # Project back
                            score_values[idx] = np.dot(eigvecs, proj_score)
        
        # Compute kernel matrix and gradient
        K = self.kernel.evaluate(particles)
        grad_K = self.kernel.gradient(particles)
        
        # Check for NaN/Inf values and clean up
        if np.any(np.isnan(K)) or np.any(np.isinf(K)):
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)
        if np.any(np.isnan(grad_K)) or np.any(np.isinf(grad_K)):
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Compute attractive forces
        attractive = np.zeros_like(particles)
        for i in range(n_particles):
            attractive[i] = np.sum(K[i, :, np.newaxis] * score_values, axis=0)
        
        # Compute repulsive forces
        repulsive = np.zeros_like(particles)
        for i in range(n_particles):
            repulsive[i] = np.sum(grad_K[:, i, :], axis=0)
        
        # Dynamic repulsion factor based on iteration - stronger early on
        # More aggressive in 3D to combat mode collapse
        if iteration < self.n_iter * 0.3:
            # Very strong repulsion in early iterations
            repulsion_factor = self.repulsion_factor * (1.0 - 0.5 * iteration / (self.n_iter * 0.3))
            repulsion_factor = max(2.0, repulsion_factor)
        elif iteration < self.n_iter * 0.6:
            # Moderate repulsion in middle iterations
            repulsion_factor = 1.8
        else:
            # Lower repulsion in final iterations for refinement
            repulsion_factor = 1.2
        
        # Apply repulsion factor
        repulsive *= repulsion_factor
        
        # Dynamic mode balancing - more aggressive in 3D
        if self.mode_assignments is not None:
            unique_modes = np.unique(self.mode_assignments)
            mode_counts = np.bincount(self.mode_assignments, minlength=len(unique_modes))
            
            if len(mode_counts) > 0 and np.any(mode_counts > 0):
                # Compute ideal count
                ideal_count = n_particles / len(unique_modes)
                
                # More aggressive balancing with higher power
                mode_weights = (ideal_count / np.maximum(mode_counts, 1)) ** 2.0
                # Cap weights to avoid extreme values
                mode_weights = np.clip(mode_weights, 0.5, 5.0)
                
                # Apply weights based on mode assignment
                particle_weights = mode_weights[self.mode_assignments]
                
                # Apply weights to forces
                attractive *= particle_weights[:, np.newaxis]
                repulsive *= particle_weights[:, np.newaxis]
        
        # Combine forces
        update = attractive + repulsive
        
        # Normalize update to avoid extremely large steps
        update_norm = np.linalg.norm(update)
        if update_norm > 1e-10:  # Avoid division by zero
            avg_particle_dist = 0
            if n_particles > 1:
                # Estimate average distance between particles
                sample_size = min(n_particles, 100)
                sampled_indices = np.random.choice(n_particles, sample_size, replace=False)
                dists = []
                for i in range(sample_size):
                    for j in range(i+1, sample_size):
                        dists.append(np.linalg.norm(particles[sampled_indices[i]] - particles[sampled_indices[j]]))
                avg_particle_dist = np.mean(dists) if dists else 1.0
            else:
                avg_particle_dist = 1.0
                
            # Scale update relative to particle distances
            scale_factor = avg_particle_dist * 0.1  # Allow 10% movement relative to average distance
            
            # Apply update scaling if norm is too large
            if update_norm > scale_factor:
                update = update * (scale_factor / update_norm)
        
        return update
    
    def _mode_balanced_resample(self, particles):
        """Enhanced mode-aware resampling for 3D"""
        if self._mode_centers is None or self.mode_assignments is None:
            return particles
        
        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]
        new_particles = particles.copy()
        
        # Count particles per mode
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        
        # Target count per mode (uniform distribution)
        target_count = n_particles / n_modes
        
        # Find underrepresented modes - use a lower threshold for 3D (60%)
        for mode_idx in range(n_modes):
            # If this mode has significantly too few particles
            if mode_counts[mode_idx] < target_count * 0.6:
                mode_deficit = int(target_count - mode_counts[mode_idx])
                
                # Generate new particles around this mode center
                mode_center = self._mode_centers[mode_idx]
                
                # Try to use covariance structure if available
                cov = None
                if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                    cov = self._mode_covs[mode_idx]
                
                # Find particles from overrepresented modes to replace
                other_modes = np.where(mode_counts > target_count * 1.2)[0]
                if len(other_modes) > 0:
                    # Get particles from most overrepresented mode
                    replace_mode = other_modes[np.argmax(mode_counts[other_modes])]
                    replace_indices = np.where(self.mode_assignments == replace_mode)[0]
                    
                    # Replace a subset of these particles
                    n_replace = min(mode_deficit, len(replace_indices))
                    replace_indices = replace_indices[:n_replace]
                    
                    # Generate new particles for underrepresented mode
                    for i, idx in enumerate(replace_indices):
                        if cov is not None:
                            try:
                                # Use cached Cholesky if available
                                if mode_idx in self._cholesky_cache:
                                    L = self._cholesky_cache[mode_idx]
                                else:
                                    # Generate correlated sample
                                    L = np.linalg.cholesky(cov + 1e-5 * np.eye(dim))
                                    self._cholesky_cache[mode_idx] = L
                                
                                # Create correlated noise
                                new_particles[idx] = mode_center + np.random.randn(dim) @ L.T * 0.5
                            except:
                                # If Cholesky fails, use eigendecomposition
                                try:
                                    eigvals, eigvecs = np.linalg.eigh(cov)
                                    eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive eigenvalues
                                    L = eigvecs @ np.diag(np.sqrt(eigvals))
                                    noise = np.random.randn(dim)
                                    # Create correlated noise
                                    new_particles[idx] = mode_center + (L @ noise) * 0.5
                                except:
                                    # Fallback to isotropic
                                    new_particles[idx] = mode_center + np.random.randn(dim) * 0.5
                        else:
                            # Isotropic normal if no covariance
                            new_particles[idx] = mode_center + np.random.randn(dim) * 0.5
                        
                        # Update mode assignment
                        self.mode_assignments[idx] = mode_idx
                    
                    # Update mode counts
                    mode_counts[replace_mode] -= n_replace
                    mode_counts[mode_idx] += n_replace
        
        return new_particles
    

    def _direct_mode_intervention(self, particles, iteration):
        """Directly intervene to maintain mode coverage in difficult cases"""
        if self._mode_centers is None:
            return particles
            
        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]
        
        # Only apply direct intervention periodically and early in optimization
        if iteration > self.n_iter * 0.5:
            return particles
            
        # Count particles per mode using current assignments
        if self.mode_assignments is not None and len(self.mode_assignments) == n_particles:
            mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
            
            # Check if any mode has less than 10% of expected count
            target_per_mode = n_particles / n_modes
            critically_low = np.where(mode_counts < target_per_mode * 0.1)[0]
            
            if len(critically_low) > 0:
                # Direct intervention needed - place particles directly at mode centers
                particles_to_move = int(target_per_mode * 0.2)  # Move 20% of expected count
                
                for mode_idx in critically_low:
                    # Find particles to replace - preferably from overrepresented modes
                    overrep_modes = np.where(mode_counts > target_per_mode * 1.5)[0]
                    
                    if len(overrep_modes) > 0:
                        source_mode = overrep_modes[0]
                        source_indices = np.where(self.mode_assignments == source_mode)[0]
                        
                        # Number of particles to move
                        n_move = min(particles_to_move, len(source_indices))
                        
                        if n_move > 0:
                            # Select indices to replace
                            move_indices = source_indices[:n_move]
                            
                            # Place directly at mode center with appropriate noise
                            mode_center = self._mode_centers[mode_idx]
                            
                            # Add correlation-aware noise if available
                            if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                                cov = self._mode_covs[mode_idx]
                                try:
                                    # Use cached Cholesky or compute it
                                    if mode_idx in self._cholesky_cache:
                                        L = self._cholesky_cache[mode_idx]
                                    else:
                                        L = np.linalg.cholesky(cov + 1e-5 * np.eye(dim))
                                        self._cholesky_cache[mode_idx] = L
                                        
                                    # Generate correlated samples
                                    for i, idx in enumerate(move_indices):
                                        particles[idx] = mode_center + np.random.randn(dim) @ L.T * 0.4
                                        # Update mode assignment
                                        self.mode_assignments[idx] = mode_idx
                                except:
                                    # Fallback to isotropic
                                    for i, idx in enumerate(move_indices):
                                        particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                        # Update mode assignment
                                        self.mode_assignments[idx] = mode_idx
                            else:
                                # Use isotropic noise
                                for i, idx in enumerate(move_indices):
                                    particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                    # Update mode assignment
                                    self.mode_assignments[idx] = mode_idx
                            
                            # Update mode counts
                            mode_counts[source_mode] -= n_move
                            mode_counts[mode_idx] += n_move
        
        return particles
    
    def update(self, particles, score_fn, target_samples=None, return_convergence=False):
        """Run enhanced ESCORT optimization for 3D"""
        # Initialize
        particles = particles.copy()
        n_particles, dim = particles.shape
        
        # Better initialization with mode coverage
        particles = self._initialize_particles(particles)
        
        # Tracking variables
        delta_norm_history = []
        step_size_history = []
        curr_step_size = self.step_size
        current_noise = self.noise_level
        
        # Prepare GSWD if target samples are provided
        if target_samples is not None and self.lambda_reg > 0:
            self.gswd.fit(target_samples, particles)
        
        # Setup progress bar
        iterator = range(self.n_iter)
        if self.verbose:
            try:
                iterator = tqdm(iterator, desc="ESCORT 3D")
            except ImportError:
                pass
        
        # Main update loop
        for t in iterator:
            # Aggressive early exploration
            if t < self.n_iter * 0.2 and t % 20 == 0:
                # Directly enforce mode coverage periodically
                particles = self._direct_mode_intervention(particles, t)
            
            # Compute SVGD update with better correlation awareness
            svgd_update = self._compute_svgd_update(particles, score_fn, t)
            
            # Add GSWD regularization - apply more frequently in 3D
            if target_samples is not None and self.lambda_reg > 0 and (t % 3 == 0):
                try:
                    # Update gswd regularization
                    gswd_reg = self.gswd.get_regularizer(target_samples, particles, self.lambda_reg)
                    update = svgd_update + gswd_reg
                except Exception as e:
                    update = svgd_update
            else:
                update = svgd_update
            
            # Add noise with better decaying schedule for 3D
            if current_noise > 0:
                if t < self.n_iter * 0.3:
                    # Strong noise early on - every iteration
                    noise = np.random.randn(*particles.shape) * current_noise * 1.5
                    update = update + noise
                elif t % 3 == 0 and t < self.n_iter * 0.6:
                    # Moderate noise in middle phase - every 3 iterations
                    noise = np.random.randn(*particles.shape) * current_noise * 0.8
                    update = update + noise
                elif t % 5 == 0 and t < self.n_iter * 0.8:
                    # Light noise in later phase - less frequent
                    noise = np.random.randn(*particles.shape) * current_noise * 0.5
                    update = update + noise
                
                # Better noise decay schedule for 3D
                if t < self.n_iter * 0.3:
                    # Slower decay in early iterations
                    current_noise *= self.noise_decay ** 0.5
                else:
                    # Normal decay later
                    current_noise *= self.noise_decay
            
            # Apply update
            new_particles = particles + curr_step_size * update
            
            # Mode-based resampling - more frequent in 3D
            if t > 0 and t % self.mode_balance_freq == 0:
                # First update mode assignments
                self._update_mode_assignments(new_particles)
                # Then rebalance
                new_particles = self._mode_balanced_resample(new_particles)
            
            # Check convergence
            delta = new_particles - particles
            delta_norm = np.linalg.norm(delta) / n_particles
            delta_norm_history.append(delta_norm)
            step_size_history.append(curr_step_size)
            
            # Update particles
            particles = new_particles
            
            # Step size decay - gentler for 3D
            if self.decay_step_size:
                if t < self.n_iter * 0.3:
                    # Maintain larger steps initially
                    curr_step_size = self.step_size / (1.0 + 0.005 * t)
                else:
                    # Moderate decay later
                    curr_step_size = self.step_size / (1.0 + 0.01 * t)
            
            # Early stopping with less frequent checking for efficiency
            if t > self.n_iter * 0.7 and t % 10 == 0 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break
        
        # Update iterations run if didn't break early
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")
        
        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run
            }
            return particles, convergence_info
        
        return particles
    
    def fit_transform(self, initial_particles, score_fn, target_samples=None, 
                     return_convergence=False, reset=True):
        """Run the optimizer on initial particles"""
        if reset:
            self.detected_modes = None
            self.mode_assignments = None
            self._cholesky_cache = {}
            
            # Set up mode information from target_info
            if self.target_info is not None:
                self._mode_centers = self.target_info.get('centers', None)
                self._mode_covs = self.target_info.get('covs', None)
            else:
                self._mode_centers = None
                self._mode_covs = None
        
        return self.update(initial_particles, score_fn, target_samples, return_convergence)


class ESCORT3DAdapter:
    """
    # & Thin wrapper that exposes ESCORT3D through the same fit_transform
    # & call signature used by the other methods in this comparison
    """
    def __init__(self, n_iter=300, step_size=0.02, verbose=True, target_info=None):
        """
        # & Build the wrapped ESCORT3D instance
        # &
        # & Args:
        # &     n_iter (int): Number of iterations passed to ESCORT3D
        # &     step_size (float): Step size passed to ESCORT3D
        # &     verbose (bool): Verbosity flag passed to ESCORT3D
        # &     target_info (dict): Optional target description passed to ESCORT3D
        """
        # Caller-supplied settings are forwarded unchanged; the noise and
        # regularization values are fixed for this adapter.
        escort_config = dict(
            step_size=step_size,
            n_iter=n_iter,
            verbose=verbose,
            noise_level=0.15,
            noise_decay=0.98,
            lambda_reg=0.2,
            target_info=target_info,
        )
        self.escort = ESCORT3D(**escort_config)

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Delegate to the wrapped ESCORT3D instance and return its result
        """
        optimizer = self.escort
        result = optimizer.fit_transform(
            initial_particles, score_fn, target_samples, return_convergence)
        return result


# ========================================
# Multi-seed Experiment Functions
# ========================================

def run_experiment(methods_to_run=None, n_iter=300, step_size=0.01, verbose=True, seed=None,
                   n_particles=1000):
    """
    # & Run experiment comparing different methods with a specific seed
    # &
    # & Args:
    # &     methods_to_run (list): Methods to evaluate; defaults to
    # &         ['ESCORT3D', 'SVGD', 'DVRL', 'SIR']. Names outside that set
    # &         are ignored.
    # &     n_iter (int): Number of iterations (passed to ESCORT3D and SVGD)
    # &     step_size (float): Step size for updates (passed to ESCORT3D and SVGD)
    # &     verbose (bool): Whether to display progress
    # &     seed (int): Random seed for NumPy and torch; None leaves the
    # &         random generators untouched
    # &     n_particles (int): Number of target samples drawn and of initial
    # &         particles created
    # &
    # & Returns:
    # &     tuple: (results_df, target_distribution, particles_dict, convergence_dict)
    # &         results_df has one row per method that was run. A method that
    # &         raises is still reported, using slightly perturbed initial
    # &         particles and a runtime of 0.0.
    """
    if verbose:
        print(f"Running experiment with seed {seed}...")

    if methods_to_run is None:
        methods_to_run = ['ESCORT3D', 'SVGD', 'DVRL', 'SIR']

    # Set random seed if provided
    if seed is not None:
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        torch.manual_seed(seed)

    # Target distribution, reference samples and the shared starting particles
    target_gmm = HighlyCorrelated3DGMMDistribution()
    target_samples = target_gmm.sample(n_particles)
    initial_particles = np.random.randn(n_particles, 3) * 3
    score_fn = target_gmm.score

    methods = {}
    particles_dict = {}
    convergence_dict = {}
    results_dict = {}

    # Description of the target handed to the methods that accept it
    target_info = {
        'n_modes': target_gmm.n_components,
        'centers': target_gmm.means,
        'covs': list(target_gmm.covs),
    }

    # Add methods based on what's requested
    if 'ESCORT3D' in methods_to_run:
        methods['ESCORT3D'] = ESCORT3DAdapter(
            n_iter=n_iter, step_size=step_size, verbose=verbose, target_info=target_info)

    if 'SVGD' in methods_to_run:
        methods['SVGD'] = StableSVGDAdapter(
            n_iter=n_iter,
            step_size=step_size,
            verbose=verbose,
            target_info=target_info
        )

    if 'DVRL' in methods_to_run:
        try:
            # Initialize the DVRL model
            dvrl = DVRL(
                obs_dim=3,          # 3D state space
                action_dim=1,       # Simple 1D actions for testing
                h_dim=64,           # Hidden state dimension
                z_dim=3,            # Latent state dimension (matches state dimension)
                n_particles=100,    # Use fewer particles for stability
                continuous_actions=True
            )

            # Explicitly move model to CPU
            dvrl = dvrl.to(torch.device('cpu'))

            methods['DVRL'] = DVRLAdapter3D(dvrl, n_samples=n_particles)
        except Exception as e:
            print(f"Error initializing DVRL: {e}")

            # Fallback that hands back a copy of the initial particles
            def _passthrough(initial_particles, score_fn, target_samples=None,
                             return_convergence=False):
                particles = initial_particles.copy()
                if return_convergence:
                    return particles, {"iterations": 0}
                return particles

            methods['DVRL'] = _passthrough

    if 'SIR' in methods_to_run:
        methods['SIR'] = SIRAdapter(n_iter=1)

    def _record(name, particles, convergence, runtime):
        """Store a method's outputs and its evaluation (tagged with the seed)."""
        particles_dict[name] = particles
        convergence_dict[name] = convergence
        evaluation = evaluate_method_3d(
            name, particles, target_gmm, target_samples, runtime=runtime)
        evaluation['Seed'] = seed
        results_dict[name] = evaluation

    # Run each method
    for method_name, method in methods.items():
        if verbose:
            print(f"Running {method_name}...")

        try:
            # Adapters expose fit_transform; the DVRL fallback is a bare callable
            run = getattr(method, 'fit_transform', method)

            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments
            started = time.perf_counter()
            particles, convergence = run(
                initial_particles.copy(), score_fn, target_samples, return_convergence=True)
            elapsed = time.perf_counter() - started

            _record(method_name, particles, convergence, elapsed)

            if verbose:
                print(f"{method_name} completed in {elapsed:.2f} seconds")
        except Exception as e:
            print(f"Error in {method_name}: {e}")
            traceback.print_exc()

            # Report the failed method on lightly perturbed initial particles
            # so it still appears in the results table
            fallback = initial_particles.copy() + np.random.randn(*initial_particles.shape) * 0.1
            _record(method_name, fallback, {"iterations": 0}, 0.0)

    # Create results DataFrame
    results_df = pd.DataFrame.from_dict(results_dict, orient='index')

    return results_df, target_gmm, particles_dict, convergence_dict


def run_experiment_with_multiple_seeds(methods_to_run=None, n_runs=5, seeds=None, **kwargs):
    """
    # & Run experiment with multiple seeds for robust evaluation
    # &
    # & Args:
    # &     methods_to_run (list): Methods to evaluate
    # &     n_runs (int): Number of runs with different seeds. Only used when
    # &         seeds is None; otherwise one run is made per entry of seeds.
    # &     seeds (list): List of seeds to use. If None, random seeds will be generated.
    # &     **kwargs: Additional arguments for method configuration. Only
    # &         'verbose', 'n_iter' and 'step_size' are read.
    # &
    # & Returns:
    # &     tuple: (mean_results_df, all_results_df, particles_dict_last_run,
    # &             convergence_dict_last_run, target_gmm)
    # &
    # & Raises:
    # &     ValueError: If there are no seeds to run
    """
    # The number of runs actually performed is the number of seeds
    total_runs = n_runs if seeds is None else len(seeds)
    print(f"Starting 3D GMM evaluation experiment with {total_runs} different seeds...")

    if total_runs < 1:
        raise ValueError("At least one run (seed) is required")

    # Set default methods if not specified
    if methods_to_run is None:
        methods_to_run = ["ESCORT3D", "SVGD", "DVRL", "SIR"]

    # Generate random seeds if not provided
    if seeds is None:
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=n_runs)

    # Settings shared by every run
    run_kwargs = {
        'verbose': kwargs.get('verbose', True),
        'n_iter': kwargs.get('n_iter', 300),
        'step_size': kwargs.get('step_size', 0.01),
    }

    # One dict per (run, method) row
    all_results = []

    # Particles/convergence of the most recent run (for visualization)
    last_run_particles = {}
    last_run_convergence = {}
    target_gmm = None

    # Run the experiment multiple times with different seeds
    for run_idx, seed in enumerate(seeds):
        print(f"\n=== Run {run_idx+1}/{total_runs} (Seed: {seed}) ===")

        results_df, curr_target_gmm, particles_dict, convergence_dict = run_experiment(
            methods_to_run=methods_to_run,
            seed=seed,
            **run_kwargs
        )

        # Store target GMM for reference (the first run's instance is kept)
        if target_gmm is None:
            target_gmm = curr_target_gmm

        # Overwritten on every pass, so after the loop these hold the final
        # run regardless of how many seeds were supplied
        last_run_particles = particles_dict
        last_run_convergence = convergence_dict

        # Add run information to results
        results_df['Run'] = run_idx + 1

        # Add rows to all_results
        for _, row in results_df.iterrows():
            all_results.append(row.to_dict())

    # Convert all results to a DataFrame
    all_results_df = pd.DataFrame(all_results)

    # Calculate mean and standard error for each metric and method
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)',
               'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein', 'Runtime (s)']

    mean_results = []

    for method in methods_to_run:
        method_data = all_results_df[all_results_df["Method"] == method]

        method_stats = {"Method": method}

        for metric in metrics:
            values = method_data[metric].values
            mean_val = np.mean(values)
            se_val = sem(values)  # Standard error of the mean

            # Numeric columns for plotting plus a formatted column for display
            method_stats[f"{metric}_mean"] = mean_val
            method_stats[f"{metric}_se"] = se_val
            method_stats[f"{metric}"] = f"{mean_val:.4f} ± {se_val:.4f}"

        mean_results.append(method_stats)

    # Convert to DataFrame
    mean_results_df = pd.DataFrame(mean_results)
    mean_results_df = mean_results_df.set_index('Method')

    # Print results
    print("\nResults Summary (Mean ± Standard Error):")
    display_cols = list(metrics)
    print(mean_results_df[display_cols])

    return mean_results_df, all_results_df, last_run_particles, last_run_convergence, target_gmm


def _close_figure(fig):
    """
    # & Close a matplotlib figure if one was created (None is ignored)
    """
    if fig is not None:
        plt.close(fig)


def _draw_annotated_bars(methods, heights, errors, labels, colors, label_rotation=0):
    """
    # & Draw an error-bar chart on the current axes and write a label above each bar
    """
    bars = plt.bar(methods, heights, color=colors, yerr=errors, capsize=10, alpha=0.7)
    for bar, err, label in zip(bars, errors, labels):
        plt.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + err + 0.01,
                 label, ha='center', va='bottom', rotation=label_rotation, fontsize=9)
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)


def _draw_method_boxplot(data, methods, title):
    """
    # & Draw one box per method on the current axes using the shared styling
    """
    plt.boxplot(data, labels=methods, patch_artist=True,
                boxprops=dict(facecolor='lightblue', color='blue'),
                whiskerprops=dict(color='blue'),
                capprops=dict(color='blue'),
                medianprops=dict(color='red'))
    plt.title(title)
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)


def visualize_results_with_error_bars(mean_results_df, all_results_df, particles_dict,
                                      convergence_dict, target_gmm):
    """
    # & Create visualizations of the results with error bars for multiple runs
    # &
    # & Writes PNG files into the "plots_3d_multiseed" directory under
    # & SCRIPT_DIR: the target distribution, one approximation plot per
    # & method, convergence curves, bar charts with error bars, box plots
    # & and a summary table. Each plot is attempted independently; a failure
    # & is printed and the remaining plots are still produced.
    # &
    # & Args:
    # &     mean_results_df: DataFrame indexed by method with "<metric>_mean",
    # &         "<metric>_se" and formatted "<metric>" columns
    # &     all_results_df: DataFrame with results from all runs
    # &     particles_dict: Dictionary with particles from the last run
    # &     convergence_dict: Dictionary with convergence data from the last run
    # &     target_gmm: Target GMM distribution
    """
    # Create directory for plots
    plots_dir = os.path.join(SCRIPT_DIR, "plots_3d_multiseed")
    os.makedirs(plots_dir, exist_ok=True)

    viz = GMMVisualizer(cmap='viridis', figsize=(15, 12))

    methods = list(mean_results_df.index)

    # Metrics shown in the bar and box plots (runtime gets its own panel)
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)',
               'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein']
    kl_metrics = ('KL(Target||Method)', 'KL(Method||Target)')
    higher_is_better = ('Mode Coverage', 'ESS')

    # One color per method, cycling if there are more methods than colors
    palette = ['blue', 'green', 'red', 'purple', 'orange', 'brown']
    bar_colors = [palette[i % len(palette)] for i in range(len(methods))]

    # Fixed NumPy seed so the figures are repeatable (presumably the
    # visualizer draws its samples through NumPy - verify against GMMVisualizer)
    np.random.seed(42)

    # Visualize target distribution
    fig_target = None
    try:
        fig_target, _ = viz.visualize_3d(
            target_gmm,
            title="Target 3D Correlated GMM Distribution",
            show_components=True,
            n_samples=1000)
        fig_target.savefig(os.path.join(plots_dir, "target_3d_distribution.png"), dpi=300)
    except Exception as e:
        print(f"Error visualizing target distribution: {e}")
    finally:
        _close_figure(fig_target)

    # Visualize each method's approximation from the last run
    for method_name, particles in particles_dict.items():
        fig = None
        try:
            fig, ax = viz.visualize_3d(
                target_gmm,
                title=f"{method_name} Approximation\n"
                      f"Mode Coverage: {mean_results_df.loc[method_name, 'Mode Coverage']}, "
                      f"Correlation Error: {mean_results_df.loc[method_name, 'Correlation Error']}",
                show_components=True,
                n_samples=0,  # Don't plot target samples
                alpha_surface=0.15  # More transparent surfaces
            )

            # Overlay the method's particles on the target components
            ax.scatter(
                particles[:, 0], particles[:, 1], particles[:, 2],
                s=8, alpha=0.6, c='red', label=f'{method_name} Particles'
            )
            ax.legend(loc='upper right', fontsize=10)

            # Save through the returned figure rather than the implicit
            # "current figure" so the right canvas is always written
            fig.savefig(os.path.join(plots_dir, f"{method_name}_3d_approximation.png"), dpi=300)
        except Exception as e:
            print(f"Error visualizing {method_name} approximation: {e}")
        finally:
            _close_figure(fig)

    # Plot convergence for last run
    if convergence_dict:
        fig = None
        try:
            n_panels = len(convergence_dict)
            fig = plt.figure(figsize=(12, 6 * n_panels))

            for i, (method_name, convergence) in enumerate(convergence_dict.items()):
                plt.subplot(n_panels, 1, i + 1)

                history = convergence.get('delta_norm_history', [])
                if len(history) > 0:
                    # Clip extremely large values for better visualization
                    clipped_history = np.clip(history, 0, np.percentile(history, 95) * 2)
                    plt.plot(clipped_history, linewidth=2)
                    plt.xlabel('Iteration', fontsize=12)
                    plt.ylabel('Update Magnitude', fontsize=12)
                    plt.title(f'{method_name} Convergence (Last Run)', fontsize=14)
                    plt.grid(True, alpha=0.3)
                else:
                    plt.text(0.5, 0.5, "No convergence data available",
                             ha='center', va='center', fontsize=14)

            plt.tight_layout()
            fig.savefig(os.path.join(plots_dir, "convergence_plots_last_run.png"), dpi=300)
        except Exception as e:
            print(f"Error plotting convergence: {e}")
        finally:
            _close_figure(fig)

    # Figure: Metrics comparison with error bars
    fig = None
    try:
        fig = plt.figure(figsize=(18, 12))

        for i, metric in enumerate(metrics):
            plt.subplot(2, 4, i + 1)

            means = [mean_results_df.loc[method, f"{metric}_mean"] for method in methods]
            errors = [mean_results_df.loc[method, f"{metric}_se"] for method in methods]

            # KL values can be very large: clip the bar heights only, and
            # keep the true means for the printed labels
            heights = means
            if metric in kl_metrics:
                heights = np.clip(means, 0, min(20.0, max(means) * 2))

            labels = [f'{m:.4f}±{e:.4f}' for m, e in zip(means, errors)]
            _draw_annotated_bars(methods, heights, errors, labels, bar_colors,
                                 label_rotation=45)

            if metric in higher_is_better:
                plt.title(f"{metric} (higher is better)")
                plt.ylim(0, 1.1)
            else:
                plt.title(f"{metric} (lower is better)")

        # Runtime comparison takes the remaining panel of the 2x4 grid
        plt.subplot(2, 4, len(metrics) + 1)
        runtime_means = [mean_results_df.loc[method, "Runtime (s)_mean"] for method in methods]
        runtime_errors = [mean_results_df.loc[method, "Runtime (s)_se"] for method in methods]
        runtime_labels = [f'{m:.2f}±{e:.2f}s' for m, e in zip(runtime_means, runtime_errors)]
        _draw_annotated_bars(methods, runtime_means, runtime_errors, runtime_labels, bar_colors)
        plt.title("Runtime (seconds)")

        plt.tight_layout()
        fig.savefig(os.path.join(plots_dir, "metrics_comparison_with_errors.png"), dpi=300)
    except Exception as e:
        print(f"Error plotting metrics comparison with error bars: {e}")
    finally:
        _close_figure(fig)

    # Figure: Box plots showing distribution of results across runs
    fig = None
    try:
        fig = plt.figure(figsize=(18, 12))

        for i, metric in enumerate(metrics):
            plt.subplot(2, 4, i + 1)

            box_data = [all_results_df[all_results_df['Method'] == method][metric].values
                        for method in methods]
            _draw_method_boxplot(box_data, methods, f"{metric} Distribution Across Runs")

            if metric in higher_is_better:
                plt.ylim(0, 1.1)

        # Runtime boxplot in the remaining panel
        plt.subplot(2, 4, len(metrics) + 1)
        runtime_box_data = [all_results_df[all_results_df['Method'] == method]['Runtime (s)'].values
                            for method in methods]
        _draw_method_boxplot(runtime_box_data, methods, "Runtime (seconds) Distribution")

        plt.tight_layout()
        fig.savefig(os.path.join(plots_dir, "metrics_boxplots.png"), dpi=300)
    except Exception as e:
        print(f"Error plotting box plots: {e}")
    finally:
        _close_figure(fig)

    # Create summary table as an image
    fig = None
    try:
        fig = plt.figure(figsize=(12, 8))
        plt.axis('off')

        # Header row followed by one row of formatted "mean ± se" strings per method
        table_columns = ['Mode Coverage', 'Correlation Error', 'MMD',
                         'Sliced Wasserstein', 'Runtime (s)']
        table_data = [['Method', 'Mode Coverage', 'Corr. Error', 'MMD', 'SWD', 'Runtime (s)']]
        for method in methods:
            table_data.append(
                [method] + [mean_results_df.loc[method, column] for column in table_columns])

        table = plt.table(cellText=table_data, loc='center', cellLoc='center',
                          colWidths=[0.2, 0.2, 0.2, 0.15, 0.15, 0.15])
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.5)

        plt.title("Summary of Results (Mean ± Standard Error)", fontsize=16, pad=20)
        plt.tight_layout()
        fig.savefig(os.path.join(plots_dir, "summary_table.png"), dpi=300, bbox_inches='tight')
    except Exception as e:
        print(f"Error creating summary table: {e}")
    finally:
        _close_figure(fig)

    print(f"\nAll visualizations saved to {plots_dir}")


# ========================================
# Main Execution
# ========================================

if __name__ == "__main__":
    import argparse

    # Set up argument parser
    parser = argparse.ArgumentParser(description='ESCORT 3D Framework Evaluation with Multiple Seeds')
    parser.add_argument('--methods', nargs='+',
                        default=['ESCORT3D', 'SVGD', 'DVRL', 'SIR'],
                        help='Methods to evaluate (default: ESCORT3D SVGD DVRL SIR)')
    parser.add_argument('--n_runs', type=int, default=5,
                        help='Number of runs with different seeds (default: 5)')
    parser.add_argument('--n_iter', type=int, default=300,
                        help='Number of iterations (default: 300)')
    parser.add_argument('--step_size', type=float, default=0.01,
                        help='Step size for updates (default: 0.01)')
    parser.add_argument('--no_verbose', action='store_false', dest='verbose',
                        help='Disable verbose output (default: verbose enabled)')
    parser.add_argument('--fixed_seeds', action='store_true',
                        help='Use fixed seeds instead of random ones (default: False)')

    # Parse arguments
    args = parser.parse_args()

    if args.n_runs < 1:
        parser.error("--n_runs must be at least 1")

    # Configure parameters
    method_params = {
        'n_iter': args.n_iter,
        'step_size': args.step_size,
        'verbose': args.verbose,
    }

    # Use fixed seeds if requested
    if args.fixed_seeds:
        seeds = [42, 123, 456, 789, 101][:args.n_runs]  # Fixed seeds for reproducibility

        # Only five fixed seeds exist; when more runs are requested, continue
        # with a deterministic sequence so the run count matches --n_runs
        n_missing = args.n_runs - len(seeds)
        if n_missing > 0:
            seeds.extend(range(1000, 1000 + n_missing))
    else:
        # Generate random master seed
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")

        # Use master seed to generate seeds for individual runs
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=args.n_runs)

    # Print seeds being used
    print(f"Using seeds: {seeds}")

    # Run the experiment with specified methods and seeds
    (mean_results_df, all_results_df, last_run_particles,
     last_run_convergence, target_gmm) = run_experiment_with_multiple_seeds(
        methods_to_run=args.methods,
        n_runs=args.n_runs,
        seeds=seeds,
        **method_params
    )

    # Save results to CSV
    results_dir = os.path.join(SCRIPT_DIR, "results_3d_multiseed")
    os.makedirs(results_dir, exist_ok=True)
    mean_results_df.to_csv(os.path.join(results_dir, "escort_3d_mean_results.csv"))
    all_results_df.to_csv(os.path.join(results_dir, "escort_3d_all_results.csv"))

    # Visualize the results with error bars
    visualize_results_with_error_bars(
        mean_results_df, all_results_df, last_run_particles,
        last_run_convergence, target_gmm
    )

    print("\nExperiment complete. Results saved to CSV and visualizations saved as PNG files.")
    print(f"CSV results saved in: {results_dir}")
    print(f"Visualizations saved in: {os.path.join(SCRIPT_DIR, 'plots_3d_multiseed')}")

    # Print overall ranking based on key metrics
    print("\nMethod Ranking by Key Metrics:")

    # Rank by Mode Coverage (higher is better)
    mc_ranking = mean_results_df.sort_values(by='Mode Coverage_mean', ascending=False).index.tolist()
    print(f"Mode Coverage: {', '.join(mc_ranking)}")

    # Rank by Correlation Error (lower is better)
    ce_ranking = mean_results_df.sort_values(by='Correlation Error_mean', ascending=True).index.tolist()
    print(f"Correlation Error: {', '.join(ce_ranking)}")

    # Rank by MMD (lower is better)
    mmd_ranking = mean_results_df.sort_values(by='MMD_mean', ascending=True).index.tolist()
    print(f"MMD: {', '.join(mmd_ranking)}")

    def _relative_to_max(column):
        """Scale a column by its maximum; zeros when the maximum is not positive."""
        peak = column.max()
        if not peak > 0:
            # Avoids a 0/0 division that would turn every score into NaN
            return column * 0.0
        return column / peak

    # Overall performance score (normalized weighted sum)
    # Higher mode coverage is better, lower correlation error and MMD is better
    mc_scores = _relative_to_max(mean_results_df['Mode Coverage_mean'])
    ce_scores = 1 - _relative_to_max(mean_results_df['Correlation Error_mean'])
    mmd_scores = 1 - _relative_to_max(mean_results_df['MMD_mean'])

    # Weighted sum: 40% mode coverage, 40% correlation error, 20% MMD
    overall_scores = (mc_scores * 0.4) + (ce_scores * 0.4) + (mmd_scores * 0.2)
    overall_ranking = overall_scores.sort_values(ascending=False).index.tolist()

    print(f"Overall Performance: {', '.join(overall_ranking)}")

    # Report the best and worst seed (by MMD) for each method
    print("\nPerformance by Random Seed:")
    for method in args.methods:
        method_data = all_results_df[all_results_df['Method'] == method]

        print(f"\n{method}:")
        if method_data.empty:
            # e.g. a method name that the experiment does not recognize
            print("  No results recorded")
            continue

        # Average MMD per seed, ordered from best (lowest) to worst
        mmd_by_seed = method_data.groupby('Seed')['MMD'].mean().sort_values(kind='stable')

        print(f"  Best performance on seed: {mmd_by_seed.index[0]} (MMD: {mmd_by_seed.iloc[0]:.6f})")
        print(f"  Worst performance on seed: {mmd_by_seed.index[-1]} (MMD: {mmd_by_seed.iloc[-1]:.6f})")
