"""
    Enhanced ESCORT Framework Evaluation on 20D Multi-modal Correlated Distribution
    
    This script evaluates the ESCORT framework against other methods on a 
    challenging 20D multi-modal distribution with complex correlation structures.
    
    Key features:
    1. 20D distribution with varied correlation patterns across modes
    2. Visualization using dimensionality reduction techniques
    3. Comparative analysis of ESCORT, SVGD, DVRL, and SIR methods
    4. Metrics for high-dimensional distribution quality assessment
"""
import os
import sys
import time
import traceback
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import scipy
from scipy.stats import multivariate_normal
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import torch
import umap

# Import required libraries from provided modules
from belief_assessment.distributions import GMMDistribution
from belief_assessment.evaluation.visualize_distributions import GMMVisualizer

# Try to import ESCORT-related classes and helper functions
try:
    from escort.utils.kernels import RBFKernel
    from escort.gswd import GSWD
    from escort.svgd import SVGD, AdaptiveSVGD
    from dvrl.dvrl import DVRL
    from tests.evaluate_2d_1 import SIRAdapter
    from tests.evaluate_2d_1 import compute_mmd, compute_ess, compute_mode_coverage_2d
    from tests.evaluate_2d_1 import compute_correlation_error, compute_sliced_wasserstein_distance
    from tests.evaluate_2d_1 import estimate_kl_divergence_2d, evaluate_method
except ImportError as e:
    print(f"Import error: {e}")
    print("Implementing required adapter classes and metrics...")

    # NOTE(review): only SIRAdapter gets a fallback in this branch; the other
    # names imported above stay undefined after a failed import unless they
    # are defined elsewhere in the file.

    # Define fallback adapter for SIR
    class SIRAdapter:
        """
        # & Adapter for Sequential Importance Resampling
        # &
        # & Args:
        # &     n_iter (int): Number of weight/resample passes to perform
        """
        def __init__(self, n_iter=1):
            self.n_iter = n_iter

        def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
            """
            # & Reweight the particles with the target log-probability and resample
            # & (multinomial, plus Gaussian jitter) whenever the normalized ESS
            # & drops below 0.5.
            # &
            # & Args:
            # &     initial_particles (np.ndarray): Starting particles (never modified)
            # &     score_fn (callable): Called as score_fn(x, return_logp=True) first;
            # &         if that fails, score_fn(x) is used and its result is treated
            # &         as log-probabilities
            # &     target_samples: Unused, accepted for interface compatibility
            # &     return_convergence (bool): Also return a convergence dict
            # &
            # & Returns:
            # &     np.ndarray, or (np.ndarray, dict) when return_convergence is True.
            # &     On an internal error the initial particles are returned together
            # &     with {"iterations": 0}, so the return shape is always consistent.
            """
            try:
                particles = initial_particles.copy()

                for step in range(self.n_iter):
                    # Prefer the (score, log_prob) form; fall back to a plain call
                    try:
                        _, log_probs = score_fn(particles, return_logp=True)
                    except Exception:
                        log_probs = score_fn(particles)

                    # Normalized importance weights (max-shifted for stability)
                    probs = np.exp(log_probs - np.max(log_probs))
                    weights = probs / np.sum(probs)

                    # Normalized effective sample size
                    ess = 1.0 / np.sum(weights**2)
                    ess_ratio = ess / len(particles)

                    # Resample if ESS is too low
                    if ess_ratio < 0.5:
                        print(f"Iteration {step}: Resampling (ESS = {ess_ratio:.4f})")
                        # Multinomial resampling
                        indices = np.random.choice(
                            len(particles), len(particles), p=weights, replace=True
                        )
                        particles = particles[indices]

                        # Add small noise - slightly larger for 20D
                        particles += np.random.randn(*particles.shape) * 0.15

                convergence = {"iterations": self.n_iter}
            except Exception as e:
                print(f"Error in SIR: {e}")
                # Best effort: hand back the untouched input, but keep the
                # (particles, convergence) contract so callers can still unpack
                particles = initial_particles
                convergence = {"iterations": 0}

            if return_convergence:
                return particles, convergence
            return particles

# Seed NumPy's global generator once so the script is reproducible
np.random.seed(42)

# Absolute directory containing this script; used when saving outputs
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]


class HighlyCorrelated20DGMMDistribution(GMMDistribution):
    """
    # & Ten-component Gaussian mixture in 20 dimensions. Every component uses a
    # & different covariance pattern (blocks, checkerboard, band, hierarchical,
    # & long-range pairs, sparse noise) to stress correlation modelling.
    """
    def __init__(self, name=None, seed=None):
        dim = 20
        block_starts = (0, 5, 10, 15)  # four groups of five consecutive dimensions

        def link(mat, a, b, value):
            # Write one symmetric pair of off-diagonal entries
            mat[a, b] = value
            mat[b, a] = value

        def fill_block(mat, lo, hi, even_value, odd_value):
            # Couple every pair inside [lo, hi); the value is picked by the
            # parity of the index sum
            for a in range(lo, hi):
                for b in range(a + 1, hi):
                    link(mat, a, b, even_value if (a + b) % 2 == 0 else odd_value)

        def make_positive_definite(mat):
            # Tiny ridge first; if an eigenvalue is still non-positive, shift the
            # diagonal by |smallest eigenvalue| + 1e-4
            mat = mat + 1e-5 * np.eye(dim)
            smallest = np.min(np.linalg.eigvalsh(mat))
            if smallest <= 0:
                mat = mat + (abs(smallest) + 1e-4) * np.eye(dim)
            return mat

        # ---- Component means -------------------------------------------
        means = np.zeros((10, dim))

        # Mode 1: negative first half, positive second half
        means[0, :10] = -2.0
        means[0, 10:] = 2.0

        # Mode 2: alternating sign
        means[1, ::2] = 2.0
        means[1, 1::2] = -2.0

        # Mode 3: linear ramp
        means[2, :] = np.linspace(-3.0, 3.0, dim)

        # Modes 4-7: mass concentrated in one group of five dimensions each
        for k in range(4):
            means[3 + k, 5 * k:5 * k + 5] = 3.0

        # Mode 8: sinusoid
        means[7, :] = 2.0 * np.sin(np.linspace(0, 4*np.pi, dim))

        # Mode 9: parabola
        grid = np.linspace(-2, 2, dim)
        means[8, :] = 2.0 * (grid**2 - 1)

        # Mode 10: all negative
        means[9, :] = -2.0

        # ---- Component covariances (before regularization) --------------
        raw_covs = []

        # Mode 1: four blocks with sign-alternating couplings
        mode1 = np.eye(dim)
        block_values = ((0.7, -0.7), (0.8, -0.5), (0.6, -0.6), (0.65, -0.65))
        for lo, (pos, neg) in zip(block_starts, block_values):
            fill_block(mode1, lo, lo + 5, pos, neg)
        raw_covs.append(mode1)

        # Mode 2: every even dimension coupled to every odd one, checkerboard signs
        mode2 = np.eye(dim)
        for even in range(0, dim, 2):
            for odd in range(1, dim, 2):
                positive = (even // 2 + odd // 2) % 2 == 0
                link(mode2, even, odd, 0.75 if positive else -0.75)
        raw_covs.append(mode2)

        # Mode 3: band of width 5 with geometric decay
        mode3 = np.eye(dim)
        for a in range(dim):
            for b in range(a + 1, min(a + 6, dim)):
                link(mode3, a, b, 0.9 ** (b - a))
        raw_covs.append(mode3)

        # Modes 4-7: one strongly coupled group, remaining dimensions independent
        for lo in block_starts:
            single = np.eye(dim)
            fill_block(single, lo, lo + 5, 0.85, 0.85)
            raw_covs.append(single)

        # Mode 8: coupled groups plus weaker links between group leaders
        mode8 = np.eye(dim)
        for lo in block_starts:
            fill_block(mode8, lo, lo + 5, 0.7, 0.7)
        for pos, first in enumerate(block_starts):
            for later in block_starts[pos + 1:]:
                link(mode8, first, later, 0.3)
        raw_covs.append(mode8)

        # Mode 9: each dimension paired with its mirror image at the other end
        mode9 = np.eye(dim)
        for a in range(dim // 2):
            link(mode9, a, dim - 1 - a, -0.7 if a % 2 else 0.7)
        raw_covs.append(mode9)

        # Mode 10: sparse weak couplings drawn from the *global* NumPy RNG.
        # The rand()/randn() call order below is part of the behavior.
        mode10 = np.eye(dim)
        for a in range(dim):
            for b in range(a + 1, dim):
                if np.random.rand() < 0.3:
                    link(mode10, a, b, 0.1 * np.random.randn())
        raw_covs.append(mode10)

        # ---- Regularize and stack ----------------------------------------
        covs = np.array([make_positive_definite(mat) for mat in raw_covs])

        # Unequal mixture weights, renormalized to sum to one
        weights = np.array([0.12, 0.12, 0.10, 0.08, 0.08, 0.08, 0.08, 0.12, 0.12, 0.10])
        weights = weights / np.sum(weights)

        super().__init__(means, covs, weights, name=name or "Highly Correlated 20D GMM", seed=seed)


# ========================================
# 20D Method Factory, Experiment Runner, and Evaluation Metrics
# ========================================
def get_method(method_name, **kwargs):
    """
    # & Build the adapter object for one of the compared methods.
    # &
    # & Args:
    # &     method_name (str): "ESCORT20D", "SVGD", "DVRL" or "SIR" (case-insensitive)
    # &     **kwargs: Optional settings: target_info, n_iter (300), step_size (0.01),
    # &         verbose (True), n_particles (1000, DVRL only), sir_iter (1, SIR only)
    # &
    # & Returns:
    # &     object: Adapter exposing fit_transform(...)
    # &
    # & Raises:
    # &     ValueError: If method_name is not one of the known methods
    """
    key = method_name.upper()

    # Options shared by the two gradient-based adapters
    shared = {
        'n_iter': kwargs.get('n_iter', 300),
        'step_size': kwargs.get('step_size', 0.01),
        'verbose': kwargs.get('verbose', True),
        'target_info': kwargs.get('target_info', None),
    }

    if key == "ESCORT20D":
        return ESCORT20DAdapter(**shared)

    if key == "SVGD":
        return StableSVGD20DAdapter(**shared)

    if key == "SIR":
        return SIRAdapter(n_iter=kwargs.get('sir_iter', 1))

    if key != "DVRL":
        raise ValueError(f"Unknown method: {method_name}")

    # DVRL: build the model on CPU; any failure degrades to a noise-only adapter
    try:
        model = DVRL(
            obs_dim=20,         # 20D state space
            action_dim=1,       # Simple 1D actions for testing
            h_dim=128,          # Larger hidden state dimension for 20D
            z_dim=20,           # Latent state dimension (matches state dimension)
            n_particles=100,    # Use fewer particles for stability
            continuous_actions=True
        )
        model = model.to(torch.device('cpu'))
        return DVRLAdapter20D(model, n_samples=kwargs.get('n_particles', 1000))
    except Exception as e:
        print(f"Error initializing DVRL: {e}")

        class DVRLFallbackAdapter:
            # Stand-in that only jitters a copy of the initial particles
            def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
                jittered = initial_particles.copy()
                jittered += np.random.randn(*jittered.shape) * 0.1
                info = {"iterations": 0}
                if return_convergence:
                    return jittered, info
                return jittered

        return DVRLFallbackAdapter()


def run_experiment_with_multiple_seeds(methods_to_run=None, n_runs=5, seeds=None, **kwargs):
    """
    # & Run the full 20D evaluation experiment with multiple random seeds.
    # &
    # & The target distribution and its evaluation samples are built once under a
    # & fixed NumPy seed (42). Only the particle initialization changes between
    # & runs, cycling through four strategies (Gaussian, uniform, single mode,
    # & multimodal). If a method (or its evaluation) raises, the run is recorded
    # & with jittered initial particles instead, so every run yields exactly one
    # & record per method.
    # &
    # & Args:
    # &     methods_to_run (list): List of method names to run. If None, runs all methods.
    # &     n_runs (int): Number of runs with different seeds. Ignored when `seeds`
    # &         is given (the number of runs is then len(seeds)).
    # &     seeds (list): List of seeds to use. If None, random seeds will be generated.
    # &     **kwargs: Additional arguments forwarded to get_method()
    # &
    # & Returns:
    # &     tuple: (mean_results_df, all_results_df, all_particles,
    # &             all_convergence, target_gmm) where
    # &         mean_results_df is a DataFrame indexed by method with mean / standard error,
    # &         all_results_df is a DataFrame with one row per (method, run),
    # &         all_particles maps method name -> list of final particle arrays,
    # &         all_convergence maps method name -> list of convergence dicts,
    # &         target_gmm is the target distribution used for every run.
    """
    # An explicit seed list determines the number of runs, keeping the
    # progress messages consistent with what is actually executed
    if seeds is not None:
        n_runs = len(seeds)

    print(f"Starting 20D GMM evaluation experiment with {n_runs} different initializations...")

    # Set default methods if not specified
    if methods_to_run is None:
        methods_to_run = ["ESCORT20D", "SVGD", "DVRL", "SIR"]

    # Generate random seeds if not provided
    if seeds is None:
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=n_runs)

    # One result dict per (method, run)
    all_results = []

    # Final particles / convergence info / runtimes per method, one entry per run
    all_particles = {method: [] for method in methods_to_run}
    all_convergence = {method: [] for method in methods_to_run}
    methods_runtime = {}

    # Fixed seed for the target distribution: it must be identical across runs so
    # that only the initializations vary
    np.random.seed(42)

    # Create target distribution (same for all runs)
    target_gmm = HighlyCorrelated20DGMMDistribution()

    # Generate target samples for evaluation (same for all runs)
    n_eval_samples = 2000
    target_samples = target_gmm.sample(n_eval_samples)

    # Target description handed to the methods
    target_info = {
        'n_modes': len(target_gmm.means),
        'centers': target_gmm.means,
        'covs': [cov for cov in target_gmm.covs]
    }

    # Define score function (same for all runs)
    score_fn = target_gmm.score

    # Run the experiment multiple times with different initializations
    for run_idx, seed in enumerate(seeds):
        print(f"\n=== Run {run_idx+1}/{n_runs} (Initialization Seed: {seed}) ===")

        # Set the random seed for this run's initialization only
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Different initialization strategies for each run
        n_particles = 1000
        initialization_type = run_idx % 4  # Cycle through 4 different initialization strategies

        if initialization_type == 0:
            # Standard Gaussian initialization
            initial_particles = np.random.randn(n_particles, 20) * 3.0
            init_description = "Standard Gaussian"
        elif initialization_type == 1:
            # Uniform initialization
            initial_particles = np.random.uniform(-5, 5, (n_particles, 20))
            init_description = "Uniform [-5, 5]"
        elif initialization_type == 2:
            # Concentrated initialization around random center
            centers = target_gmm.means
            center_idx = np.random.choice(len(centers))
            center = centers[center_idx]
            initial_particles = np.random.randn(n_particles, 20) * 1.0 + center
            init_description = f"Concentrated around mode {center_idx}"
        else:
            # Multimodal initialization
            centers = target_gmm.means
            # Choose random subset of centers
            subset_size = min(4, len(centers))
            center_indices = np.random.choice(len(centers), subset_size, replace=False)
            selected_centers = centers[center_indices]

            # Distribute particles among these centers; the last one takes the remainder
            particles_per_center = n_particles // subset_size
            initial_particles = np.zeros((n_particles, 20))

            for i, center in enumerate(selected_centers):
                start_idx = i * particles_per_center
                end_idx = start_idx + particles_per_center if i < subset_size - 1 else n_particles
                initial_particles[start_idx:end_idx] = center + np.random.randn(end_idx - start_idx, 20) * 1.0

            init_description = f"Multimodal ({subset_size} centers)"

        print(f"  Using {init_description} initialization")

        # Create methods for this run
        for method_name in methods_to_run:
            print(f"  Running {method_name}...")

            # Add method-specific perturbation to make each method start slightly differently
            method_particles = initial_particles.copy()
            method_particles += np.random.randn(n_particles, 20) * 0.1

            # Create method instance with kwargs
            method = get_method(method_name, target_info=target_info, **kwargs)

            # Time the execution
            start_time = time.time()

            # Run and evaluate. Nothing is stored inside the try block, so a
            # failure during evaluation can no longer leave a duplicate record.
            try:
                particles, convergence = method.fit_transform(
                    method_particles,
                    score_fn,
                    target_samples=target_samples,
                    return_convergence=True
                )

                runtime = time.time() - start_time
                print(f"    Completed in {runtime:.2f} seconds")

                method_results = evaluate_method_20d(
                    method_name, particles, target_gmm, target_samples, runtime=runtime
                )

            except Exception as e:
                print(f"    Error running {method_name}: {e}")
                traceback.print_exc()

                # Fallback: jittered initial particles and dummy convergence info
                print(f"    Using fallback for {method_name}")
                particles = method_particles.copy() + np.random.randn(*method_particles.shape) * 0.1
                convergence = {"iterations": 0}
                runtime = time.time() - start_time

                method_results = evaluate_method_20d(
                    method_name, particles, target_gmm, target_samples, runtime=runtime
                )

            # Record exactly one entry per (method, run)
            all_particles[method_name].append(particles)
            all_convergence[method_name].append(convergence)
            methods_runtime.setdefault(method_name, []).append(runtime)

            # Add run information
            method_results["Run"] = run_idx + 1
            method_results["Seed"] = seed
            method_results["Initialization"] = init_description

            all_results.append(method_results)

    # Convert all results to a DataFrame
    all_results_df = pd.DataFrame(all_results)

    # Metrics summarized as mean and standard error per method
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)',
              'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein', 'Runtime (s)']

    mean_results = []

    # Calculate statistics for each method
    for method in methods_to_run:
        method_data = all_results_df[all_results_df["Method"] == method]

        method_stats = {"Method": method}

        for metric in metrics:
            values = method_data[metric].values
            mean_val = np.mean(values)
            se_val = scipy.stats.sem(values)  # Standard error of the mean

            # Store mean and standard error, plus a printable "mean ± se" string
            method_stats[f"{metric}_mean"] = mean_val
            method_stats[f"{metric}_se"] = se_val
            method_stats[f"{metric}"] = f"{mean_val:.6f} ± {se_val:.6f}"

        mean_results.append(method_stats)

    # Convert to DataFrame
    mean_results_df = pd.DataFrame(mean_results)
    mean_results_df = mean_results_df.set_index('Method')

    # Print results
    print("\nResults Summary (Mean ± Standard Error):")
    display_cols = list(metrics)
    print(mean_results_df[display_cols])

    # Compute average runtimes
    print("\nAverage Runtimes:")
    for method_name, runtimes in methods_runtime.items():
        if runtimes:
            avg_runtime = np.mean(runtimes)
            print(f"  {method_name}: {avg_runtime:.2f} seconds")

    return mean_results_df, all_results_df, all_particles, all_convergence, target_gmm

def compute_mmd_20d(particles, target_samples, bandwidth=None):
    """
    # & Compute the (unbiased, clipped at zero) squared Maximum Mean Discrepancy
    # & between particles and target samples with the kernel
    # & k(x, y) = exp(-||x - y||^2 / bandwidth).
    # &
    # & Each set is randomly subsampled to at most 500 points (global NumPy RNG),
    # & and all kernel sums are evaluated with vectorized array operations.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     target_samples (np.ndarray): Target distribution samples, shape (m, d)
    # &     bandwidth (float, optional): Kernel bandwidth. If None, twice the median
    # &         squared distance between the first 20 and first 100 (subsampled)
    # &         particles is used; if that median is not positive (e.g. all
    # &         particles coincide) the default 20.0 is used instead.
    # &
    # & Returns:
    # &     float: MMD value (>= 0)
    # &
    # & Raises:
    # &     ValueError: If either sample set is empty
    """
    particles = np.asarray(particles, dtype=float)
    target_samples = np.asarray(target_samples, dtype=float)

    # Treat a flat array as n one-dimensional points
    if particles.ndim == 1:
        particles = particles[:, None]
    if target_samples.ndim == 1:
        target_samples = target_samples[:, None]

    if len(particles) == 0 or len(target_samples) == 0:
        raise ValueError("compute_mmd_20d requires non-empty particles and target_samples")

    # Subsample for computational efficiency in 20D
    max_samples = 500  # Smaller sample for higher dimension

    if len(particles) > max_samples:
        p_indices = np.random.choice(len(particles), max_samples, replace=False)
        particles_sub = particles[p_indices]
    else:
        particles_sub = particles

    if len(target_samples) > max_samples:
        t_indices = np.random.choice(len(target_samples), max_samples, replace=False)
        target_sub = target_samples[t_indices]
    else:
        target_sub = target_samples

    def sq_dists(a, b):
        # Pairwise squared Euclidean distances, shape (len(a), len(b))
        diff = a[:, None, :] - b[None, :, :]
        return np.einsum('ijk,ijk->ij', diff, diff)

    # Use median heuristic if bandwidth not provided
    if bandwidth is None:
        subset_p = particles_sub[:100]
        bandwidth = np.median(sq_dists(subset_p[:20], subset_p)) * 2.0  # Scale up for 20D
        # A zero (or NaN) median would turn every kernel value into NaN, which
        # the final clipping would silently report as MMD = 0
        if not bandwidth > 0:
            bandwidth = 20.0  # Default higher bandwidth for 20D

    def kernel_matrix(a, b):
        return np.exp(-sq_dists(a, b) / bandwidth)

    def mean_offdiag(x):
        # Average kernel value over all ordered pairs i != j
        n = len(x)
        if n < 2:
            return 0.0
        k = kernel_matrix(x, x)
        return (k.sum() - np.trace(k)) / (n * (n - 1))

    pp_sum = mean_offdiag(particles_sub)
    tt_sum = mean_offdiag(target_sub)
    pt_sum = kernel_matrix(particles_sub, target_sub).mean()

    mmd = pp_sum + tt_sum - 2 * pt_sum
    return float(max(0.0, mmd))  # Ensure non-negative


def compute_ess_20d(particles, score_fn):
    """
    # & Compute the normalized Effective Sample Size of the particles under
    # & self-normalized weights proportional to exp(log_prob).
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate
    # &     score_fn (callable): Called as score_fn(particles, return_logp=True)
    # &         and expected to return (score, log_probs); if that call fails,
    # &         score_fn(particles) is used and its result is treated as log_probs
    # &
    # & Returns:
    # &     float: ESS divided by the number of particles; 0.0 if the
    # &         computation fails (the error is printed)
    """
    try:
        # Prefer the (score, log_prob) form. Catch Exception rather than
        # everything so Ctrl-C / SystemExit are not swallowed.
        try:
            _, log_probs = score_fn(particles, return_logp=True)
        except Exception:
            log_probs = score_fn(particles)

        # Max-shift before exponentiating for numerical stability
        max_log_prob = np.max(log_probs)
        probs = np.exp(log_probs - max_log_prob)
        weights = probs / np.sum(probs)

        # ESS = 1 / sum(w_i^2)
        ess = 1.0 / np.sum(weights**2)
        return ess / len(particles)  # Normalized ESS
    except Exception as e:
        print(f"Error computing ESS: {e}")
        return 0.0


def compute_mode_coverage_20d(particles, gmm, threshold=0.05, radius=7.0):
    """
    # & Compute the fraction of GMM modes that are covered by the particles.
    # &
    # & A mode counts as covered when at least
    # & threshold * len(particles) / n_modes particles lie within `radius` of its
    # & mean in Mahalanobis distance (Euclidean distance if the mode's covariance
    # & cannot be inverted).
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, d)
    # &     gmm (GMMDistribution): Target GMM; reads n_components, means, covs
    # &     threshold (float): Coverage threshold (fraction of a mode's fair share)
    # &     radius (float): Distance below which a particle counts as near a mode
    # &
    # & Returns:
    # &     float: Mode coverage ratio in [0,1]; 0.0 when there are no particles
    """
    particles = np.asarray(particles)
    n_particles = len(particles)

    # Without this guard an empty set would satisfy "0 >= 0" for every mode
    # and be reported as full coverage
    if n_particles == 0:
        return 0.0

    n_modes = gmm.n_components
    mode_centers = gmm.means

    modes_covered = np.zeros(n_modes, dtype=bool)

    for i in range(n_modes):
        diff = particles - mode_centers[i]

        try:
            inv_cov = np.linalg.inv(gmm.covs[i])
        except np.linalg.LinAlgError:
            # Singular covariance: fall back to Euclidean distance
            distances = np.sqrt(np.sum(diff**2, axis=1))
        else:
            # d_k^T * inv_cov * d_k for every particle k in one pass
            distances = np.sqrt(np.einsum('ij,jk,ik->i', diff, inv_cov, diff))

        close_particles = np.sum(distances < radius)

        # Enough particles near this mode?
        if close_particles >= threshold * n_particles / n_modes:
            modes_covered[i] = True

    return float(np.mean(modes_covered))


def compute_correlation_error_20d(particles, gmm):
    """
    # & Compute the error in capturing the per-mode covariance structure.
    # &
    # & Particles are clustered with KMeans (one cluster per mode), clusters are
    # & matched to GMM modes by minimum squared center distance (Hungarian
    # & assignment), and each mode's empirical covariance is compared with the
    # & true one:
    # &   * on entries whose true off-diagonal magnitude exceeds 0.3 (MSE), or
    # &   * if there are none, on the 5 largest eigenvalues (relative MSE).
    # & Modes with 50 or fewer assigned particles contribute the error 1.0.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate
    # &     gmm (GMMDistribution): Target GMM; reads n_components, means, covs
    # &
    # & Returns:
    # &     float: Mean error over modes; 1.0 if there are fewer than
    # &         20 * n_modes particles or if the computation fails
    """
    n_modes = gmm.n_components

    # Skip if too few particles
    if len(particles) < n_modes * 20:  # Need more particles for reliable 20D correlation
        return 1.0  # Maximum error

    try:
        # Use KMeans to cluster particles
        kmeans = KMeans(n_clusters=n_modes, random_state=42, max_iter=300, n_init=10)
        cluster_labels = kmeans.fit_predict(particles)
        cluster_centers = kmeans.cluster_centers_

        # Match clusters to GMM modes (Hungarian algorithm if SciPy provides it)
        try:
            from scipy.optimize import linear_sum_assignment
        except ImportError:
            # Without the solver, keep the raw cluster labels
            mode_labels = cluster_labels
        else:
            # Squared distance between every cluster center and every mode mean
            true_means = np.asarray(gmm.means)[:n_modes]
            deltas = cluster_centers[:, None, :] - true_means[None, :, :]
            cost_matrix = np.sum(deltas**2, axis=2)

            row_ind, col_ind = linear_sum_assignment(cost_matrix)

            # Relabel each particle with the mode its cluster was assigned to
            mode_of_cluster = np.empty(n_modes, dtype=int)
            mode_of_cluster[row_ind] = col_ind
            mode_labels = mode_of_cluster[cluster_labels]

        # Compute the error for each mode
        mode_errors = []

        for i in range(n_modes):
            mode_particles = particles[mode_labels == i]

            if len(mode_particles) <= 50:  # Need more samples for 20D covariance
                mode_errors.append(1.0)  # Maximum error if too few particles
                continue

            empirical_cov = np.cov(mode_particles, rowvar=False)
            true_cov = gmm.covs[i]

            # Off-diagonal entries that are strong in the true covariance
            mask = np.abs(true_cov) > 0.3
            np.fill_diagonal(mask, False)

            if np.any(mask):
                # MSE restricted to the strong entries
                corr_error = np.mean((true_cov[mask] - empirical_cov[mask])**2)
                mode_errors.append(corr_error)
            else:
                # No strong entries: compare the leading eigenvalues instead
                # (eigh returns eigenvalues in ascending order)
                true_eigvals, _ = np.linalg.eigh(true_cov)
                emp_eigvals, _ = np.linalg.eigh(empirical_cov)

                top_k = 5
                true_top = true_eigvals[-top_k:]
                emp_top = emp_eigvals[-top_k:]

                eigval_error = np.mean((true_top - emp_top)**2) / np.mean(true_top**2)
                mode_errors.append(eigval_error)

        return np.mean(mode_errors)
    except Exception as e:
        print(f"Error computing correlation error: {e}")
        return 1.0  # Maximum error on failure


def compute_sliced_wasserstein_distance_20d(particles, target_samples, n_projections=50):
    """
    # & Compute the Sliced Wasserstein Distance between two sample sets
    # & Both sets are projected onto random unit directions; the 1-Wasserstein
    # & distance of each 1D projection is averaged over all directions.
    # & The projection dimension is taken from `particles`, so the function
    # & works for 20D data as well as any other dimensionality.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, dim)
    # &     target_samples (np.ndarray): Target distribution samples, shape (m, dim)
    # &     n_projections (int): Number of random projections
    # &
    # & Returns:
    # &     float: Sliced Wasserstein Distance, or inf if the computation fails
    # &            or either sample set is empty
    """
    def _resample_sorted(sorted_vals, n_out):
        # Linearly interpolate a sorted 1D sample onto n_out evenly spaced ranks
        positions = np.linspace(0, len(sorted_vals) - 1, n_out)
        return np.interp(positions, np.arange(len(sorted_vals)), sorted_vals)

    try:
        particles = np.asarray(particles)
        target_samples = np.asarray(target_samples)

        # No distance can be estimated without samples on both sides
        if len(particles) == 0 or len(target_samples) == 0:
            return float('inf')

        # Random unit directions in the data's own dimensionality
        dim = particles.shape[1]
        directions = np.random.randn(n_projections, dim)
        directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)

        # Subsample (without replacement) to bound the cost of sorting
        max_samples = 1000
        if len(particles) > max_samples:
            p_indices = np.random.choice(len(particles), max_samples, replace=False)
            particles_sub = particles[p_indices]
        else:
            particles_sub = particles

        if len(target_samples) > max_samples:
            t_indices = np.random.choice(len(target_samples), max_samples, replace=False)
            target_sub = target_samples[t_indices]
        else:
            target_sub = target_samples

        swd = 0.0
        for direction in directions:
            # Sorted 1D projections are the empirical quantile functions
            particles_proj = np.sort(particles_sub @ direction)
            target_proj = np.sort(target_sub @ direction)

            # Stretch the shorter sample so both have the same number of points
            if len(particles_proj) > len(target_proj):
                target_proj = _resample_sorted(target_proj, len(particles_proj))
            elif len(target_proj) > len(particles_proj):
                particles_proj = _resample_sorted(particles_proj, len(target_proj))

            # 1-Wasserstein distance for this projection
            swd += np.mean(np.abs(particles_proj - target_proj))

        return swd / n_projections
    except Exception as e:
        print(f"Error computing sliced Wasserstein distance: {e}")
        return float('inf')


def estimate_kl_divergence_20d(particles, gmm, direction='forward', n_bins=10):
    """
    # & Estimate KL divergence between particles and a GMM using a projection-based approach
    # &
    # & The estimate combines histogram-based KL values computed on
    # &   1. random 1D projections,
    # &   2. the leading principal components (1D), and
    # &   3. pairs of leading principal components (2D),
    # & weighted 0.3 / 0.4 / 0.3. The data dimensionality is read from
    # & `particles`, so inputs need not be exactly 20D.
    # &
    # & Args:
    # &     particles (np.ndarray): Particles to evaluate, shape (n, dim)
    # &     gmm (GMMDistribution): Target distribution; only `gmm.sample(n)` is used
    # &     direction (str): 'reverse' puts the GMM samples in the first KL slot
    # &         (KL(gmm||particles)); any other value gives KL(particles||gmm)
    # &     n_bins (int): Number of histogram bins per axis
    # &
    # & Returns:
    # &     float: Estimated KL divergence; inf if the estimation fails entirely
    """
    epsilon = 1e-10  # keeps empty bins from producing log(0) or division by zero

    def _smooth_and_normalize(hist):
        # Add epsilon to every bin, then rescale to a probability vector
        hist = hist + epsilon
        return hist / np.sum(hist)

    def _kl_1d(p_proj, q_proj):
        # Histogram KL of two 1D samples on a shared set of bin edges
        lo = min(np.min(p_proj), np.min(q_proj))
        hi = max(np.max(p_proj), np.max(q_proj))
        if hi == lo:
            # Both samples sit on the same single value: identical distributions
            return 0.0
        bin_edges = np.linspace(lo, hi, n_bins + 1)
        p_hist, _ = np.histogram(p_proj, bins=bin_edges, density=True)
        q_hist, _ = np.histogram(q_proj, bins=bin_edges, density=True)
        p_hist = _smooth_and_normalize(p_hist)
        q_hist = _smooth_and_normalize(q_hist)
        return np.sum(p_hist * np.log(p_hist / q_hist))

    def _kl_2d(p_x, p_y, q_x, q_y):
        # Histogram KL of two 2D samples; returns None when it cannot be formed
        edges = []
        for p_axis, q_axis in ((p_x, q_x), (p_y, q_y)):
            lo = min(np.min(p_axis), np.min(q_axis))
            hi = max(np.max(p_axis), np.max(q_axis))
            if hi == lo:
                return None
            edges.append(np.linspace(lo, hi, n_bins + 1))
        # Shared edges so that cell (a, b) covers the same region for p and q
        p_hist, _, _ = np.histogram2d(p_x, p_y, bins=edges, density=True)
        q_hist, _, _ = np.histogram2d(q_x, q_y, bins=edges, density=True)
        p_hist = _smooth_and_normalize(p_hist)
        q_hist = _smooth_and_normalize(q_hist)
        mask = (p_hist > epsilon) & (q_hist > epsilon)
        if not np.any(mask):
            return None
        return np.sum(p_hist[mask] * np.log(p_hist[mask] / q_hist[mask]))

    try:
        particles = np.asarray(particles)
        gmm_samples = np.asarray(gmm.sample(len(particles)))

        # Decide which sample set plays the role of p in KL(p||q)
        if direction == 'reverse':
            p_samples, q_samples = gmm_samples, particles
        else:  # forward
            p_samples, q_samples = particles, gmm_samples

        dim = p_samples.shape[1]

        # 1. KL on random 1D projections (unit vectors in the data's dimension)
        n_projections = 20
        projections = np.random.randn(n_projections, dim)
        projections = projections / np.linalg.norm(projections, axis=1, keepdims=True)
        kl_projections = [
            _kl_1d(p_samples @ unit, q_samples @ unit) for unit in projections
        ]

        # 2. KL on principal components
        try:
            # Fit PCA on the combined samples to get a common basis
            all_samples = np.vstack([p_samples, q_samples])
            # PCA cannot extract more components than features or samples
            n_components = min(5, dim, len(all_samples))
            pca = PCA(n_components=n_components)
            pca.fit(all_samples)

            p_pca = pca.transform(p_samples)
            q_pca = pca.transform(q_samples)

            # 1D KL along each retained principal component
            kl_pca = [
                _kl_1d(p_pca[:, i], q_pca[:, i]) for i in range(p_pca.shape[1])
            ]

            # 2D KL for pairs among the leading components
            kl_pca_2d = []
            for i in range(min(3, p_pca.shape[1])):
                for j in range(i + 1, min(4, p_pca.shape[1])):
                    kl_div_2d = _kl_2d(p_pca[:, i], p_pca[:, j],
                                       q_pca[:, i], q_pca[:, j])
                    if kl_div_2d is not None:
                        kl_pca_2d.append(kl_div_2d)

            # 3. Combine the estimates; PCA-based terms get the larger share
            if kl_pca and kl_projections:
                kl_random = np.mean(kl_projections)
                kl_pca_1d = np.mean(kl_pca)
                # Fall back to the 1D PCA value when no 2D estimate exists
                kl_pca_2d_mean = np.mean(kl_pca_2d) if kl_pca_2d else kl_pca_1d
                return 0.3 * kl_random + 0.4 * kl_pca_1d + 0.3 * kl_pca_2d_mean
            elif kl_projections:
                return np.mean(kl_projections)
            else:
                # Fallback if all else fails
                return 10.0  # Default high value
        except Exception as e:
            print(f"Error in PCA-based KL estimation: {e}")
            if kl_projections:
                return np.mean(kl_projections)
            else:
                return 10.0  # Default high value
    except Exception as e:
        print(f"Error estimating KL divergence: {e}")
        return float('inf')


def evaluate_method_20d(method_name, particles, gmm, target_samples, runtime=None):
    """
    # & Evaluate method performance using multiple metrics for 20D case
    # &
    # & Args:
    # &     method_name (str): Name of the method
    # &     particles (np.ndarray): Particles from the method
    # &     gmm (GMMDistribution): Target GMM distribution
    # &     target_samples (np.ndarray): Target distribution samples
    # &     runtime (float, optional): Method runtime in seconds
    # &
    # & Returns:
    # &     dict: Evaluation metrics keyed by metric name; 'Runtime (s)' is
    # &           included only when `runtime` is given
    """
    results = {
        'Method': method_name,
        'MMD': compute_mmd_20d(particles, target_samples),
        # Both KL entries must compare the method's particles with the target.
        # direction='reverse' places the GMM samples first: KL(target||particles).
        'KL(Target||Method)': estimate_kl_divergence_20d(particles, gmm, direction='reverse'),
        # direction='forward' places the particles first: KL(particles||target).
        'KL(Method||Target)': estimate_kl_divergence_20d(particles, gmm, direction='forward'),
        'Mode Coverage': compute_mode_coverage_20d(particles, gmm),
        'Correlation Error': compute_correlation_error_20d(particles, gmm),
        'ESS': compute_ess_20d(particles, gmm.score),
        'Sliced Wasserstein': compute_sliced_wasserstein_distance_20d(particles, target_samples),
    }

    if runtime is not None:
        results['Runtime (s)'] = runtime

    return results


# ========================================
# Method Adapters for 20D
# ========================================

class DVRLAdapter20D:
    """
    # & Adapter for DVRL to match interface with other methods for 20D data
    # &
    # & The wrapped model is probed for a `sample`, `generate` or `forward`
    # & method (in that order; only the first one found is tried) and its
    # & output is converted to a NumPy array of `n_samples` particles.
    """
    def __init__(self, dvrl_model, n_samples=1000):
        # Model object to draw samples from
        self.dvrl_model = dvrl_model
        # Number of particles fit_transform returns on the normal path
        self.n_samples = n_samples

    @staticmethod
    def _as_numpy(samples):
        """
        # & Convert a model output to np.ndarray
        # & Returns None for anything that is neither a torch.Tensor nor an
        # & np.ndarray, so the caller keeps its current particles.
        """
        if isinstance(samples, torch.Tensor):
            return samples.detach().cpu().numpy()
        if isinstance(samples, np.ndarray):
            return samples
        return None

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Adapts DVRL to the common interface for 20D
        # &
        # & Args:
        # &     initial_particles (np.ndarray): Shape (n, dim); defines the expected
        # &         dimensionality and serves as the fallback particle set
        # &     score_fn: Unused; accepted for interface compatibility
        # &     target_samples: Unused; accepted for interface compatibility
        # &     return_convergence (bool): Also return a placeholder convergence dict
        # &
        # & Returns:
        # &     np.ndarray, or (np.ndarray, dict) when return_convergence is True
        """
        dim = initial_particles.shape[1]  # Get dimensionality of initial particles
        particles = initial_particles.copy()
        model = self.dvrl_model

        try:
            # Try to find what methods are available in DVRL
            if callable(getattr(model, 'sample', None)):
                try:
                    drawn = self._as_numpy(model.sample(n_samples=self.n_samples))
                    if drawn is not None:
                        particles = drawn
                except Exception as e:
                    print(f"Error using DVRL sample method: {e}")
            elif callable(getattr(model, 'generate', None)):
                try:
                    drawn = self._as_numpy(model.generate(n_samples=self.n_samples))
                    if drawn is not None:
                        particles = drawn
                except Exception as e:
                    print(f"Error using DVRL generate method: {e}")
            elif callable(getattr(model, 'forward', None)):
                try:
                    # Create a simple observation tensor on the model's device
                    dummy_obs = torch.zeros((1, model.obs_dim)).to(
                        next(model.parameters()).device)
                    output = model.forward(dummy_obs)

                    # Handle different possible output formats
                    if isinstance(output, tuple) and len(output) > 0:
                        samples = output[0]  # Take first element if tuple
                    else:
                        samples = output

                    # A single-row output is repeated to the requested particle count
                    if isinstance(samples, torch.Tensor):
                        if samples.shape[0] == 1:
                            samples = samples.repeat(self.n_samples, 1)
                    elif (isinstance(samples, np.ndarray) and samples.ndim == 2
                          and samples.shape[0] == 1):
                        samples = np.tile(samples, (self.n_samples, 1))

                    drawn = self._as_numpy(samples)
                    if drawn is not None:
                        particles = drawn
                except Exception as e:
                    print(f"Error using DVRL forward method: {e}")

            # Check if dimensions match
            if particles.shape[1] != dim:
                print(f"Warning: DVRL output dimensions ({particles.shape[1]}) don't match input dimensions ({dim})")
                # Generate fallback particles based on input statistics
                particles = np.random.randn(self.n_samples, dim) * np.std(initial_particles, axis=0)
                particles += np.mean(initial_particles, axis=0)

            # Ensure we have the right number of particles
            if len(particles) != self.n_samples:
                # Resample; replacement is needed only when upsampling
                indices = np.random.choice(len(particles), self.n_samples, replace=len(particles) < self.n_samples)
                particles = particles[indices]

        except Exception as e:
            print(f"Failed to generate particles from DVRL, using fallback: {e}")
            # Use initial particles but add some noise to provide variety
            particles = initial_particles.copy()
            particles += np.random.randn(*particles.shape) * 0.2

        # Create convergence info
        if return_convergence:
            convergence_info = {
                'iterations': 1,
                'delta_norm_history': [0.0],  # Placeholder history
                'step_size_history': [0.0]    # Placeholder history
            }
            return particles, convergence_info
        else:
            return particles


class StableSVGD20D(SVGD):
    """
    # & Enhanced SVGD with aggressive mode exploration for 20D multi-modal distributions
    """
    def __init__(self, kernel=None, step_size=0.01, n_iter=300, tol=1e-5, 
                 bandwidth_scale=0.5, add_noise=True, noise_level=0.4, 
                 noise_decay=0.97, resample_freq=6, adaptive_step=True, 
                 mode_detection=True, lambda_corr=0.7, verbose=True,
                 target_info=None):
        """
        # & Set up the 20D SVGD variant: parent SVGD state plus mode-handling options
        # &
        # & Args:
        # &     kernel: Kernel forwarded to SVGD; when None an
        # &         RBFKernel(bandwidth=1.0, adaptive=True) is created
        # &     step_size, n_iter, tol, verbose: Forwarded unchanged to SVGD.__init__
        # &     bandwidth_scale, add_noise, noise_level, noise_decay, resample_freq,
        # &         adaptive_step, mode_detection, lambda_corr: Stored as attributes
        # &         of the same name
        # &     target_info (dict, optional): Stored as-is; initialize_particles
        # &         reads its 'centers' and 'covs' entries
        """
        # Fall back to the default RBF kernel only when the caller supplied none
        chosen_kernel = RBFKernel(bandwidth=1.0, adaptive=True) if kernel is None else kernel

        # Parent class handles kernel, step size, iteration count and tolerance
        super().__init__(
            kernel=chosen_kernel,
            step_size=step_size,
            n_iter=n_iter,
            tol=tol,
            verbose=verbose,
        )

        # --- Options taken directly from the constructor arguments ---
        self.target_info = target_info
        self.mode_detection = mode_detection
        self.adaptive_step = adaptive_step
        self.resample_freq = resample_freq
        self.bandwidth_scale = bandwidth_scale
        self.lambda_corr = lambda_corr
        self.add_noise = add_noise
        self.noise_level = noise_level
        self.noise_decay = noise_decay

        # --- Mode bookkeeping, filled in once particles are initialized/assigned ---
        self._mode_centers = None
        self._mode_covs = None
        self._cholesky_cache = {}  # mode index -> Cholesky factor of its covariance
        self.mode_assignments = None
        self.detected_modes = None

        # --- Fixed settings for mode balancing and intervention ---
        self.missing_mode_threshold = 0.01  # fraction of the per-mode target count
        self.direct_intervention_freq = 6
        self.mode_balance_freq = 5

        # --- Fixed settings for correlation handling ---
        self.repulsion_factor = 4.0
        self.correlation_scale = 1.5
        self.use_mahalanobis = True  # Mahalanobis distance for mode assignment

        # Counter of iterations run so far
        self.iterations_run = 0
        
    def initialize_particles(self, particles):
        """
        # & Initialize particles with more uniform mode coverage for 20D
        """
        n_particles, dim = particles.shape
        
        # If we have target info with known modes, use it
        if self.target_info is not None and 'centers' in self.target_info:
            centers = self.target_info['centers']
            covs = self.target_info.get('covs', None)
            
            # Store mode information for later use
            self._mode_centers = centers
            self._mode_covs = covs
            
            # Number of modes
            n_modes = len(centers)
            
            # Distribute particles evenly across all modes
            new_particles = np.zeros((n_particles, dim))
            
            # More uniform distribution than before
            particles_per_mode = [n_particles // n_modes] * n_modes
            
            # Account for rounding
            remainder = n_particles - sum(particles_per_mode)
            for i in range(remainder):
                particles_per_mode[i] += 1
            
            # Initialize mode assignments
            self.mode_assignments = np.zeros(n_particles, dtype=int)
            
            idx = 0
            for i in range(n_modes):
                n_mode = particles_per_mode[i]
                
                # Initialize particles around this mode with correlation structure
                if covs is not None and i < len(covs):
                    try:
                        cov = covs[i]
                        # Add small regularization for numerical stability
                        cov_reg = cov + 1e-5 * np.eye(dim)
                        
                        # Try Cholesky first
                        try:
                            L = np.linalg.cholesky(cov_reg)
                            self._cholesky_cache[i] = L
                            # Generate correlated random samples
                            z = np.random.randn(n_mode, dim)
                            # Use tighter spread for 20D
                            correlated = z @ L.T * 0.4
                            new_particles[idx:idx+n_mode] = centers[i] + correlated
                        except:
                            # If Cholesky fails, use eigendecomposition
                            eigvals, eigvecs = np.linalg.eigh(cov_reg)
                            eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive
                            L = eigvecs @ np.diag(np.sqrt(eigvals))
                            z = np.random.randn(n_mode, dim)
                            correlated = z @ L.T * 0.4
                            new_particles[idx:idx+n_mode] = centers[i] + correlated
                    except:
                        # Ultimate fallback to isotropic
                        new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.5
                else:
                    # Use isotropic normal if no covariance available
                    new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.5
                
                # Assign mode labels
                self.mode_assignments[idx:idx+n_mode] = i
                
                idx += n_mode
            
            # Add extra exploration noise to a subset of particles
            # For 20D, we use more exploration particles
            explore_fraction = 0.15  # 15% of particles get extra noise
            explore_count = int(n_particles * explore_fraction)
            if explore_count > 0:
                explore_indices = np.random.choice(n_particles, explore_count, replace=False)
                new_particles[explore_indices] += np.random.randn(explore_count, dim) * 2.0
            
            return new_particles
        
        # If no target info, just return original particles
        return particles
    
    def _update_mode_assignments(self, particles):
        """
        # & Update mode assignments using Mahalanobis distance when possible for 20D
        """
        if self._mode_centers is None:
            return None, None
        
        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        
        # Initialize or reset mode assignments
        if self.mode_assignments is None or len(self.mode_assignments) != n_particles:
            self.mode_assignments = np.zeros(n_particles, dtype=int)
        
        # Use Mahalanobis distance when covariance is available
        if self.use_mahalanobis and self._mode_covs is not None:
            distances = np.zeros((n_particles, n_modes))
            
            for i, center in enumerate(self._mode_centers):
                if i < len(self._mode_covs):
                    try:
                        # Add regularization for numerical stability
                        cov = self._mode_covs[i] + 1e-5 * np.eye(particles.shape[1])
                        inv_cov = np.linalg.inv(cov)
                        
                        # Compute Mahalanobis distance
                        diff = particles - center
                        for j in range(n_particles):
                            distances[j, i] = np.sqrt(diff[j] @ inv_cov @ diff[j])
                    except:
                        # Fallback to Euclidean
                        diff = particles - center
                        distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
                else:
                    # Use Euclidean if no covariance
                    diff = particles - center
                    distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
        else:
            # Use Euclidean distance
            distances = np.zeros((n_particles, n_modes))
            for i, center in enumerate(self._mode_centers):
                diff = particles - center
                distances[:, i] = np.sqrt(np.sum(diff**2, axis=1))
        
        # Assign each particle to nearest center
        self.mode_assignments = np.argmin(distances, axis=1)
        
        return self._mode_centers, self.mode_assignments
    
    def _direct_mode_intervention(self, particles, iteration):
        """
        # & Directly intervene to maintain mode coverage in 20D
        """
        if self._mode_centers is None:
            return particles
            
        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]
        
        # Update mode assignments first
        self._update_mode_assignments(particles)
        
        # Count particles per mode
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        
        # Check for severely underrepresented modes
        target_per_mode = n_particles / n_modes
        
        # More aggressive threshold for 20D
        critically_low = np.where(mode_counts < target_per_mode * self.missing_mode_threshold)[0]
        
        if len(critically_low) > 0:
            # More particles to move for 20D - up to 40% of expected count
            particles_to_move = min(int(target_per_mode * 0.4), 
                                   int(n_particles * 0.08))  # but never more than 8% of total
            
            for mode_idx in critically_low:
                # Find particles to replace from most populated modes
                most_pop_mode = np.argmax(mode_counts)
                if most_pop_mode == mode_idx:
                    # Find next most populated mode
                    temp_counts = mode_counts.copy()
                    temp_counts[most_pop_mode] = 0
                    if np.sum(temp_counts) == 0:
                        continue  # No other modes have particles
                    most_pop_mode = np.argmax(temp_counts)
                    
                source_indices = np.where(self.mode_assignments == most_pop_mode)[0]
                
                # Number of particles to move - more aggressive for 20D
                n_move = min(particles_to_move, len(source_indices), 
                             int(mode_counts[most_pop_mode] - target_per_mode * 0.4))
                
                if n_move > 0:
                    # Select indices to replace
                    move_indices = source_indices[:n_move]
                    
                    # Place at mode center with appropriate correlation-aware noise
                    mode_center = self._mode_centers[mode_idx]
                    
                    # Add correlated noise if available
                    if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                        cov = self._mode_covs[mode_idx]
                        try:
                            # Try Cholesky decomposition
                            if mode_idx in self._cholesky_cache:
                                L = self._cholesky_cache[mode_idx]
                            else:
                                # Add regularization
                                cov_reg = cov + 1e-5 * np.eye(dim)
                                L = np.linalg.cholesky(cov_reg)
                                self._cholesky_cache[mode_idx] = L
                                
                            # Generate correlated samples
                            for i, idx in enumerate(move_indices):
                                particles[idx] = mode_center + np.random.randn(dim) @ L.T * 0.4
                                self.mode_assignments[idx] = mode_idx
                        except:
                            # Try eigendecomposition if Cholesky fails
                            try:
                                eigvals, eigvecs = np.linalg.eigh(cov + 1e-5 * np.eye(dim))
                                eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive eigenvalues
                                L = eigvecs @ np.diag(np.sqrt(eigvals))
                                
                                # Generate correlated samples
                                for i, idx in enumerate(move_indices):
                                    z = np.random.randn(dim)
                                    correlated_noise = L @ z * 0.4
                                    particles[idx] = mode_center + correlated_noise
                                    self.mode_assignments[idx] = mode_idx
                            except:
                                # Fallback to isotropic noise
                                for i, idx in enumerate(move_indices):
                                    particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                    self.mode_assignments[idx] = mode_idx
                    else:
                        # Use isotropic noise if no covariance
                        for i, idx in enumerate(move_indices):
                            particles[idx] = mode_center + np.random.randn(dim) * 0.4
                            self.mode_assignments[idx] = mode_idx
                    
                    # Update mode counts
                    mode_counts[most_pop_mode] -= n_move
                    mode_counts[mode_idx] += n_move
        
        # Also check for moderately underrepresented modes (higher threshold)
        # In 20D, we're more aggressive with redistribution
        moderately_low = np.where((mode_counts >= target_per_mode * self.missing_mode_threshold) & 
                                 (mode_counts < target_per_mode * 0.3))[0]  # Lower threshold for 20D
        
        if len(moderately_low) > 0 and iteration < self.n_iter * 0.6:  # More intervention in 20D
            # For moderately underrepresented, move more particles in 20D
            particles_to_move = int(target_per_mode * 0.2)  # 20% of expected
            
            for mode_idx in moderately_low:
                # Only move particles if we're sufficiently below target
                if mode_counts[mode_idx] < target_per_mode * 0.3:
                    # Find particles from most populated mode
                    most_pop_mode = np.argmax(mode_counts)
                    if most_pop_mode == mode_idx or mode_counts[most_pop_mode] < target_per_mode * 1.2:
                        continue  # Skip if not enough excess particles
                        
                    source_indices = np.where(self.mode_assignments == most_pop_mode)[0]
                    
                    # Move a smaller number
                    n_move = min(particles_to_move, len(source_indices),
                                int(mode_counts[most_pop_mode] - target_per_mode))
                    
                    if n_move > 0:
                        # Select indices and move with correlated noise
                        move_indices = source_indices[:n_move]
                        mode_center = self._mode_centers[mode_idx]
                        
                        if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                            try:
                                # Use cached Cholesky if available
                                if mode_idx in self._cholesky_cache:
                                    L = self._cholesky_cache[mode_idx]
                                    for i, idx in enumerate(move_indices):
                                        particles[idx] = mode_center + np.random.randn(dim) @ L.T * 0.4
                                        self.mode_assignments[idx] = mode_idx
                                else:
                                    # Try eigendecomposition
                                    eigvals, eigvecs = np.linalg.eigh(self._mode_covs[mode_idx] + 1e-5 * np.eye(dim))
                                    eigvals = np.maximum(eigvals, 1e-6)
                                    L = eigvecs @ np.diag(np.sqrt(eigvals))
                                    for i, idx in enumerate(move_indices):
                                        z = np.random.randn(dim)
                                        particles[idx] = mode_center + L @ z * 0.4
                                        self.mode_assignments[idx] = mode_idx
                            except:
                                # Fallback to isotropic
                                for i, idx in enumerate(move_indices):
                                    particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                    self.mode_assignments[idx] = mode_idx
                        else:
                            # Use isotropic noise
                            for i, idx in enumerate(move_indices):
                                particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                self.mode_assignments[idx] = mode_idx
                        
                        # Update mode counts
                        mode_counts[most_pop_mode] -= n_move
                        mode_counts[mode_idx] += n_move
        
        return particles
    
    def _compute_svgd_update(self, particles, score_fn, iteration=0):
        """
        # & Compute SVGD update with improved correlation handling for 20D
        """
        n_particles, dim = particles.shape
        
        # Update mode assignments more frequently in 20D
        if self.mode_detection and (iteration == 0 or iteration % 4 == 0):
            self._update_mode_assignments(particles)
        
        # Get score function values with error handling
        try:
            score_values = score_fn(particles)
            
            # Check for NaN or inf values
            if np.any(np.isnan(score_values)) or np.any(np.isinf(score_values)):
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
        except:
            # If score function fails, use zero scores
            score_values = np.zeros_like(particles)
        
        # Enhanced correlation guidance using eigendecomposition - more important in 20D
        if self._mode_covs is not None and self.mode_assignments is not None:
            # Pre-compute eigendecompositions for efficiency
            mode_eig_cache = {}
            
            for mode_idx, cov in enumerate(self._mode_covs):
                if mode_idx not in mode_eig_cache:
                    try:
                        eigvals, eigvecs = np.linalg.eigh(cov)
                        eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive
                        mode_eig_cache[mode_idx] = (eigvals, eigvecs)
                    except:
                        # Identity fallback
                        mode_eig_cache[mode_idx] = (np.ones(dim), np.eye(dim))
            
            # Apply scaled correlation guidance to each particle
            # More aggressive for 20D to handle the higher dimensional challenges
            for mode_idx in np.unique(self.mode_assignments):
                if mode_idx >= len(self._mode_covs):
                    continue
                    
                # Get particles in this mode
                mode_mask = self.mode_assignments == mode_idx
                mode_indices = np.where(mode_mask)[0]
                
                if len(mode_indices) > 0 and mode_idx in mode_eig_cache:
                    eigvals, eigvecs = mode_eig_cache[mode_idx]
                    
                    # Apply correlation-aware scaling to score values
                    for idx in mode_indices:
                        # Project score onto eigenvectors
                        proj_score = eigvecs.T @ score_values[idx]
                        
                        # Scale by sqrt of eigenvalues with higher correlation emphasis for 20D
                        # Using higher power (correlation_scale) to emphasize correlation differences
                        scaling = np.sqrt(eigvals) ** self.correlation_scale
                        proj_score = proj_score * scaling
                        
                        # Project back to original space
                        score_values[idx] = eigvecs @ proj_score
        
        # Compute kernel matrix and gradient
        K = self.kernel.evaluate(particles)
        grad_K = self.kernel.gradient(particles)
        
        # Handle numerical issues
        if np.any(np.isnan(K)) or np.any(np.isinf(K)):
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)
        if np.any(np.isnan(grad_K)) or np.any(np.isinf(grad_K)):
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Compute attractive forces
        attractive = np.zeros_like(particles)
        for i in range(n_particles):
            attractive[i] = np.sum(K[i, :, np.newaxis] * score_values, axis=0)
        
        # Compute repulsive forces
        repulsive = np.zeros_like(particles)
        for i in range(n_particles):
            repulsive[i] = np.sum(grad_K[:, i, :], axis=0)
        
        # Dynamic repulsion factor - stronger for 20D, especially early
        if iteration < self.n_iter * 0.2:
            # Very strong early repulsion for 20D
            repulsion_factor = self.repulsion_factor * (1.0 - iteration / (self.n_iter * 0.2) * 0.3)
        elif iteration < self.n_iter * 0.5:
            # Moderate middle-stage repulsion
            repulsion_factor = self.repulsion_factor * 0.8
        else:
            # Gentler late-stage repulsion but still stronger than lower dimensions
            repulsion_factor = self.repulsion_factor * 0.6
        
        # Apply repulsion factor
        repulsive *= repulsion_factor
        
        # More aggressive mode balancing for 20D
        if self.mode_detection and self.mode_assignments is not None:
            unique_modes = np.unique(self.mode_assignments)
            mode_counts = np.bincount(self.mode_assignments, minlength=len(unique_modes))
            
            if len(mode_counts) > 0 and np.any(mode_counts > 0):
                # Compute target counts
                target_counts = np.ones_like(mode_counts) * (n_particles / len(unique_modes))
                
                # Even more aggressive weighting for 20D
                mode_weights = (target_counts / np.maximum(mode_counts, 1)) ** 2.0
                
                # Cap weights to avoid numerical issues - higher caps for 20D
                mode_weights = np.clip(mode_weights, 0.4, 10.0)
                
                # Apply weights based on mode assignment
                for mode_idx, weight in enumerate(mode_weights):
                    mode_mask = self.mode_assignments == mode_idx
                    attractive[mode_mask] *= weight
                    
                    # Scale repulsion inversely in very low-count modes to avoid dispersion
                    if mode_counts[mode_idx] < target_counts[mode_idx] * 0.2:
                        repulsive[mode_mask] *= min(1.0, weight * 0.5)
                    else:
                        repulsive[mode_mask] *= weight
        
        # Combine attractive and repulsive terms
        update = attractive + repulsive
        
        # Handle numerical issues
        if np.any(np.isnan(update)) or np.any(np.isinf(update)):
            update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Clip extreme updates relative to average particle distance
        # More important in 20D where jumps can be larger
        try:
            # Sample-based estimate of average particle distance
            sample_size = min(40, n_particles)
            if sample_size > 1:
                indices = np.random.choice(n_particles, sample_size, replace=False)
                dists = []
                for i in range(sample_size):
                    for j in range(i+1, sample_size):
                        dists.append(np.linalg.norm(particles[indices[i]] - particles[indices[j]]))
                avg_dist = np.mean(dists) if dists else 1.0
            else:
                avg_dist = 1.0
                
            # Clip updates relative to this distance - be more conservative in 20D
            max_norm = avg_dist * 0.1  # Allow up to 10% movement relative to avg distance
            update_norms = np.sqrt(np.sum(update**2, axis=1))
            large_updates = update_norms > max_norm
            if np.any(large_updates):
                scale_factors = max_norm / update_norms[large_updates]
                update[large_updates] *= scale_factors[:, np.newaxis]
        except:
            # Fallback to simple clipping if estimation fails
            update_norms = np.sqrt(np.sum(update**2, axis=1))
            max_norm = 1.0
            large_updates = update_norms > max_norm
            if np.any(large_updates):
                scale_factors = max_norm / update_norms[large_updates]
                update[large_updates] *= scale_factors[:, np.newaxis]
        
        return update
    
    def update(self, particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Enhanced SVGD optimization with aggressive mode exploration for 20D
        """
        # Make a copy of initial particles
        particles = particles.copy()
        n_particles, dim = particles.shape
        
        # Improved initialization
        particles = self.initialize_particles(particles)
        
        # Tracking variables
        curr_step_size = self.step_size
        current_noise = self.noise_level if self.add_noise else 0.0
        delta_norm_history = []
        step_size_history = []
        
        # Set up progress bar if verbose
        iterator = range(self.n_iter)
        if self.verbose:
            try:
                iterator = tqdm(iterator, desc="Stable SVGD 20D")
            except ImportError:
                pass
        
        # Force early mode intervention to ensure coverage - more important in 20D
        if self._mode_centers is not None:
            particles = self._direct_mode_intervention(particles, 0)
        
        # Main optimization loop
        for t in iterator:
            # Frequent direct mode intervention in early iterations
            # More frequent for 20D
            if t < self.n_iter * 0.6 and t % self.direct_intervention_freq == 0:
                particles = self._direct_mode_intervention(particles, t)
                
            # Compute SVGD update
            update = self._compute_svgd_update(particles, score_fn, t)
            
            # Enhanced noise schedule with exploration phase - stronger for 20D
            if current_noise > 0:
                if t < self.n_iter * 0.25:
                    # Very strong early exploration noise for 20D
                    noise_scale = current_noise * 2.0
                    noise = np.random.randn(*particles.shape) * noise_scale
                    update = update + noise
                elif t < self.n_iter * 0.5:
                    # Moderate noise in middle phase
                    noise_scale = current_noise * 1.2
                    if t % 2 == 0:  # Every other iteration
                        noise = np.random.randn(*particles.shape) * noise_scale
                        update = update + noise
                elif t < self.n_iter * 0.75:
                    # Light noise in later phase
                    noise_scale = current_noise * 0.6
                    if t % 3 == 0:  # Every third iteration
                        noise = np.random.randn(*particles.shape) * noise_scale
                        update = update + noise
                
                # Slower noise decay in early iterations for 20D
                if t < self.n_iter * 0.25:
                    current_noise *= self.noise_decay ** 0.3  # Very slow decay early
                else:
                    current_noise *= self.noise_decay
            
            # Apply update with step size
            new_particles = particles + curr_step_size * update
            
            # More frequent mode-balancing resampling in 20D
            if t > 0:
                if t < self.n_iter * 0.3 and t % (self.mode_balance_freq // 2) == 0:
                    # Very frequent in early phase for 20D
                    new_particles = self._mode_balanced_resample(new_particles)
                elif t < self.n_iter * 0.7 and t % self.mode_balance_freq == 0:
                    # Regular frequency in middle phase
                    new_particles = self._mode_balanced_resample(new_particles)
                elif t % (self.mode_balance_freq * 2) == 0:
                    # Less frequent in late phase
                    new_particles = self._mode_balanced_resample(new_particles)
            
            # Compute convergence metric with safety checks
            delta = new_particles - particles
            delta_norm = np.linalg.norm(delta) / (n_particles * dim)
            
            # Handle numerical instability
            if np.isnan(delta_norm) or np.isinf(delta_norm) or delta_norm > 1e10:
                curr_step_size *= 0.1
                if self.verbose:
                    print(f"Unstable update detected! Reducing step size to {curr_step_size:.6f}")
                continue
            
            # Record history
            delta_norm_history.append(delta_norm)
            step_size_history.append(curr_step_size)
            
            # Update particles
            particles = new_particles
            
            # Improved step size decay schedule - slower for 20D
            if self.adaptive_step:
                if t < self.n_iter * 0.4:
                    # Maintain larger steps longer for exploration in 20D
                    curr_step_size = self.step_size / (1.0 + 0.002 * t)
                else:
                    # Faster decay later for refinement
                    curr_step_size = self.step_size / (1.0 + 0.01 * t)
            
            # Check for convergence
            if t > self.n_iter // 2 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")
        
        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run
            }
            return particles, convergence_info
        
        return particles
    
    def _mode_balanced_resample(self, particles):
        """
        # & Enhanced mode-balancing resampling with correlation preservation for 20D
        """
        if self._mode_centers is None or self.mode_assignments is None:
            return particles
        
        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]
        new_particles = particles.copy()
        
        # Count particles per mode
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        
        # Target count per mode - uniform distribution
        target_count = n_particles / n_modes
        
        # Find significantly underrepresented modes
        # More aggressive threshold for 20D (40% of target)
        for mode_idx in range(n_modes):
            # Check if this mode is significantly underrepresented
            if mode_counts[mode_idx] < target_count * 0.4:
                mode_deficit = int(target_count - mode_counts[mode_idx])
                
                # Generate new particles around this mode center
                mode_center = self._mode_centers[mode_idx]
                
                # Find particles from overrepresented modes to replace
                other_modes = np.where(mode_counts > target_count * 1.1)[0]  # Lower threshold for 20D
                if len(other_modes) > 0:
                    # Get particles from most overrepresented mode
                    replace_mode = other_modes[np.argmax(mode_counts[other_modes])]
                    replace_indices = np.where(self.mode_assignments == replace_mode)[0]
                    
                    # Replace a subset of these particles - more aggressive for 20D
                    n_replace = min(mode_deficit, len(replace_indices))
                    replace_indices = replace_indices[:n_replace]
                    
                    # Generate new particles with correlation structure
                    if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                        cov = self._mode_covs[mode_idx]
                        try:
                            # Try using cached Cholesky decomposition
                            if mode_idx in self._cholesky_cache:
                                L = self._cholesky_cache[mode_idx]
                            else:
                                # Compute new decomposition with regularization
                                L = np.linalg.cholesky(cov + 1e-5 * np.eye(dim))
                                self._cholesky_cache[mode_idx] = L
                                
                            # Generate correlated samples - tighter spread for 20D
                            for i, idx in enumerate(replace_indices):
                                new_particles[idx] = mode_center + np.random.randn(dim) @ L.T * 0.4
                                self.mode_assignments[idx] = mode_idx
                        except:
                            # Try eigendecomposition if Cholesky fails
                            try:
                                eigvals, eigvecs = np.linalg.eigh(cov + 1e-5 * np.eye(dim))
                                eigvals = np.maximum(eigvals, 1e-6)  # Ensure positive eigenvalues
                                L = eigvecs @ np.diag(np.sqrt(eigvals))
                                
                                # Generate correlated samples
                                for i, idx in enumerate(replace_indices):
                                    noise = np.random.randn(dim)
                                    new_particles[idx] = mode_center + (L @ noise) * 0.4
                                    self.mode_assignments[idx] = mode_idx
                            except:
                                # Fallback to isotropic noise
                                for i, idx in enumerate(replace_indices):
                                    new_particles[idx] = mode_center + np.random.randn(dim) * 0.4
                                    self.mode_assignments[idx] = mode_idx
                    else:
                        # Use isotropic noise if no covariance information
                        for i, idx in enumerate(replace_indices):
                            new_particles[idx] = mode_center + np.random.randn(dim) * 0.4
                            self.mode_assignments[idx] = mode_idx
                    
                    # Update mode counts
                    mode_counts[replace_mode] -= n_replace
                    mode_counts[mode_idx] += n_replace
        
        return new_particles
        
    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Public interface for optimizer
        """
        return self.update(initial_particles, score_fn, target_samples, return_convergence)


class StableSVGD20DAdapter:
    """
    # & Adapter for enhanced StableSVGD20D with improved multi-modal support

    Wraps a ``StableSVGD20D`` instance built with fixed 20D tuning values and
    exposes it through ``fit_transform``.
    """
    def __init__(self, n_iter=300, step_size=0.01, verbose=True, target_info=None):
        # Keyword set handed to the optimizer; only step_size, n_iter, verbose
        # and target_info come from the caller, the rest is fixed 20D tuning.
        optimizer_settings = {
            'step_size': step_size,
            'n_iter': n_iter,
            'verbose': verbose,
            'add_noise': True,
            'noise_level': 0.4,       # Higher noise for 20D exploration
            'noise_decay': 0.97,      # Slower decay to maintain exploration
            'resample_freq': 6,       # More frequent resampling
            'adaptive_step': True,
            'mode_detection': True,
            'lambda_corr': 0.7,       # Higher correlation emphasis
            'target_info': target_info,
        }
        self.svgd = StableSVGD20D(**optimizer_settings)

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        # Delegate to the wrapped optimizer, passing all arguments positionally
        run = self.svgd.fit_transform
        return run(initial_particles, score_fn, target_samples, return_convergence)


# ========================================
# ESCORT 20D Implementation
# ========================================

class ESCORT20D(AdaptiveSVGD):
    """
    # & 20D implementation of ESCORT framework for POMDPs
    # & Optimized for very high-dimensional, multi-modal correlated distributions
    """
    def __init__(self, kernel=None, gswd=None, step_size=0.02,
                n_iter=300, tol=1e-5, lambda_reg=0.5,  # Increased lambda_reg for 20D
                decay_step_size=True, verbose=True,
                noise_level=0.3, noise_decay=0.97,  # Higher noise, slower decay for 20D
                target_info=None):
        """
        # & Configure the 20D ESCORT optimizer and its mode-tracking state

        ``kernel`` and ``gswd`` default to an adaptive ``RBFKernel`` and an
        80-projection correlation-aware ``GSWD``. The optimizer settings
        (``step_size``, ``n_iter``, ``tol``, ``lambda_reg``,
        ``decay_step_size``, ``verbose``) are forwarded to the parent
        constructor; the noise settings and ``target_info`` are stored on
        this instance.
        """
        # Fall back to an adaptive RBF kernel when none is supplied
        kernel = RBFKernel(adaptive=True) if kernel is None else kernel

        # Fall back to a GSWD with more projections for 20D
        if gswd is None:
            gswd_settings = dict(n_projections=80, projection_method='random',
                                 optimization_steps=15, correlation_aware=True)
            gswd = GSWD(**gswd_settings)

        # Initialize parent class
        parent_settings = dict(
            kernel=kernel, gswd=gswd, step_size=step_size,
            n_iter=n_iter, tol=tol, lambda_reg=lambda_reg,
            decay_step_size=decay_step_size, verbose=verbose,
        )
        super().__init__(**parent_settings)

        # Target information (if available) and mode bookkeeping
        self.target_info = target_info
        self.detected_modes = None
        self.mode_assignments = None

        # Noise schedule parameters
        self.noise_level, self.noise_decay = noise_level, noise_decay

        # Cached mode data, filled in lazily by the other methods
        self._mode_centers = None
        self._mode_covs = None
        self._cholesky_cache = {}

        # Number of iterations actually executed by the last run
        self.iterations_run = 0

        # Additional tuning parameters for 20D
        self.mode_balance_freq = 6          # Balance modes more frequently
        self.aggressive_exploration = True  # Enable aggressive exploration
        self.repulsion_factor = 5.0         # Much stronger repulsion in 20D
        self.intervention_freq = 6          # More frequent intervention
        self.correlation_scale = 1.8        # Stronger correlation emphasis for 20D
    
    def _initialize_particles(self, particles):
        """
        # & Optimized particle initialization for 20D with better mode coverage
        """
        n_particles, dim = particles.shape
        
        # Only optimize if we have target info
        if self.target_info is not None and 'centers' in self.target_info:
            centers = self.target_info['centers']
            covs = self.target_info.get('covs', None)
            
            # Cache mode information
            self._mode_centers = centers
            self._mode_covs = covs
            
            # Allocate particles to modes
            n_modes = len(centers)
            new_particles = np.zeros_like(particles)
            
            # Better distribution across modes - bias toward complex modes for 20D
            particles_per_mode = []
            total_allocated = 0
            
            for i in range(n_modes):
                # Allocate more particles to modes with stronger correlation
                if covs is not None and i < len(covs):
                    # Compute correlation strength using off-diagonal elements
                    cov = covs[i]
                    off_diag_sum = np.sum(np.abs(cov - np.diag(np.diag(cov))))
                    # Normalize by maximum possible off-diagonal elements
                    max_off_diag = dim * (dim - 1)
                    relative_complexity = off_diag_sum / max_off_diag
                    
                    # More bias for 20D - allocate 30-50% more to complex correlation modes
                    corr_factor = 1.0 + relative_complexity * 0.5
                    mode_count = int(n_particles / n_modes * corr_factor)
                else:
                    mode_count = n_particles // n_modes
                
                particles_per_mode.append(mode_count)
                total_allocated += mode_count
            
            # Adjust to match exactly n_particles
            while total_allocated > n_particles:
                idx = np.argmax(particles_per_mode)
                particles_per_mode[idx] -= 1
                total_allocated -= 1
                
            while total_allocated < n_particles:
                idx = np.argmin(particles_per_mode)
                particles_per_mode[idx] += 1
                total_allocated += 1
            
            idx = 0
            for i in range(n_modes):
                n_mode = particles_per_mode[i]
                
                # Generate correlated samples efficiently
                try:
                    # Cache Cholesky decompositions
                    if i not in self._cholesky_cache and covs is not None and i < len(covs):
                        # Add small regularization for numerical stability
                        cov_reg = covs[i] + 1e-5 * np.eye(dim)
                        try:
                            L = np.linalg.cholesky(cov_reg)
                            self._cholesky_cache[i] = L
                        except:
                            # If Cholesky fails, use eigendecomposition
                            eigvals, eigvecs = np.linalg.eigh(cov_reg)
                            eigvals = np.maximum(eigvals, 1e-5)  # Ensure positive eigenvalues
                            L = eigvecs @ np.diag(np.sqrt(eigvals))
                            self._cholesky_cache[i] = L
                    
                    if i in self._cholesky_cache:
                        L = self._cholesky_cache[i]
                        # Use vectorized operations with appropriate scaling
                        z = np.random.randn(n_mode, dim)
                        # Tighter scaling for 20D
                        correlated = np.dot(z, L.T) * 0.4
                        new_particles[idx:idx+n_mode] = centers[i] + correlated
                    else:
                        # If no cached decomposition, use isotropic
                        new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.4
                except:
                    # Simple fallback
                    new_particles[idx:idx+n_mode] = centers[i] + np.random.randn(n_mode, dim) * 0.4
                
                idx += n_mode
            
            # Initialize mode assignments
            self.mode_assignments = np.zeros(n_particles, dtype=int)
            start_idx = 0
            for i, n_mode in enumerate(particles_per_mode):
                self.mode_assignments[start_idx:start_idx+n_mode] = i
                start_idx += n_mode
            
            self.detected_modes = centers
            
            # For 20D, add some exploration particles with larger noise
            explore_count = int(n_particles * 0.1)  # 10% exploration
            if explore_count > 0:
                explore_indices = np.random.choice(n_particles, explore_count, replace=False)
                new_particles[explore_indices] += np.random.randn(explore_count, dim) * 2.5
            
            return new_particles
        
        return particles
    
    def _update_mode_assignments(self, particles):
        """
        # & Efficiently update mode assignments for 20D with Mahalanobis distance
        """
        n_particles = len(particles)
        
        # If we already have mode centers, assign to nearest considering correlation
        if self._mode_centers is not None:
            # Initialize or reset mode assignments if needed
            if self.mode_assignments is None or len(self.mode_assignments) != n_particles:
                self.mode_assignments = np.zeros(n_particles, dtype=int)
            
            # Use Mahalanobis distance when possible for better correlation-aware assignment
            if self._mode_covs is not None:
                # For each particle, compute Mahalanobis distance to each mode
                # Processing in batches for efficiency in 20D
                n_modes = len(self._mode_centers)
                distances = np.zeros((n_particles, n_modes))
                
                batch_size = 100  # Process in batches
                for batch_start in range(0, n_particles, batch_size):
                    batch_end = min(batch_start + batch_size, n_particles)
                    batch_particles = particles[batch_start:batch_end]
                    
                    for i, center in enumerate(self._mode_centers):
                        if i < len(self._mode_covs):
                            try:
                                # Try to compute inverse covariance
                                cov = self._mode_covs[i]
                                # Add small regularization for numerical stability
                                cov_reg = cov + 1e-5 * np.eye(cov.shape[0])
                                inv_cov = np.linalg.inv(cov_reg)
                                
                                # Compute Mahalanobis distance for this batch
                                diff = batch_particles - center
                                for j in range(len(batch_particles)):
                                    d = diff[j]
                                    distances[batch_start + j, i] = np.sqrt(d @ inv_cov @ d)
                            except:
                                # Fallback to Euclidean if inverse fails
                                diff = batch_particles - center
                                distances[batch_start:batch_end, i] = np.sqrt(np.sum(diff**2, axis=1))
                        else:
                            # Fallback to Euclidean for modes without covariance
                            diff = batch_particles - center
                            distances[batch_start:batch_end, i] = np.sqrt(np.sum(diff**2, axis=1))
            else:
                # Use Euclidean distance if no covariance info
                distances = np.zeros((n_particles, len(self._mode_centers)))
                
                # Process in batches for efficiency in 20D
                batch_size = 100
                for batch_start in range(0, n_particles, batch_size):
                    batch_end = min(batch_start + batch_size, n_particles)
                    batch_particles = particles[batch_start:batch_end]
                    
                    for i, center in enumerate(self._mode_centers):
                        diff = batch_particles - center
                        distances[batch_start:batch_end, i] = np.sqrt(np.sum(diff**2, axis=1))
            
            # Assign to closest center
            self.mode_assignments = np.argmin(distances, axis=1)
            
            return True
        
        return False
    
    def _compute_svgd_update(self, particles, score_fn, iteration=0):
        """
        # & Compute SVGD update with better correlation awareness and mode balancing for 20D
        # &
        # & Args:
        # &     particles: Array of shape (n_particles, dim) with the current particles
        # &     score_fn: Callable returning one score row per particle
        # &     iteration: Current iteration index; drives the refresh and repulsion schedules
        # &
        # & Returns:
        # &     Array shaped like particles holding the combined, norm-limited update
        """
        n_particles, _ = particles.shape

        # Update mode assignments periodically - more frequent updates for 20D
        update_freq = max(3, min(6, self.n_iter // 50))
        if iteration == 0 or (iteration % update_freq == 0 and self._mode_centers is not None):
            self._update_mode_assignments(particles)

        # Get score values; a failing score function degrades to zero scores
        try:
            # Work on a private float copy: the scores are rescaled in place below
            # and the array handed back by score_fn must stay untouched.
            score_values = np.array(score_fn(particles), dtype=float)

            # Replace NaN/Inf entries with zeros
            if not np.all(np.isfinite(score_values)):
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
        except Exception:
            # Best effort by design, but no longer swallows KeyboardInterrupt/SystemExit
            score_values = np.zeros_like(particles)

        # Apply correlation guidance: rescale each mode's scores along the
        # eigen-directions of that mode's covariance
        if self._mode_covs is not None and self.mode_assignments is not None:
            n_covs = len(self._mode_covs)

            # Only modes that currently own particles need a decomposition
            for mode_idx in np.unique(self.mode_assignments):
                if mode_idx >= n_covs:
                    continue

                try:
                    eigvals, eigvecs = np.linalg.eigh(self._mode_covs[mode_idx])
                except Exception:
                    # Identity fallback: scores of this mode are left unchanged
                    continue

                # Ensure positivity for numerical stability
                eigvals = np.maximum(eigvals, 1e-6)
                gain = np.sqrt(eigvals) ** self.correlation_scale

                # Project onto the eigenvectors, scale, and project back (row form)
                members = np.where(self.mode_assignments == mode_idx)[0]
                score_values[members] = ((score_values[members] @ eigvecs) * gain) @ eigvecs.T

        # Compute kernel matrix and gradient
        K = self.kernel.evaluate(particles)
        grad_K = self.kernel.gradient(particles)

        # Clean up NaN/Inf values
        if not np.all(np.isfinite(K)):
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)
        if not np.all(np.isfinite(grad_K)):
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)

        # Attractive forces: attractive[i] = sum_j K[i, j] * score_values[j]
        attractive = np.zeros_like(particles)
        attractive[...] = K @ score_values

        # Repulsive forces: repulsive[i] = sum_j grad_K[j, i, :]
        repulsive = np.zeros_like(particles)
        repulsive[...] = grad_K.sum(axis=0)

        # Dynamic repulsion factor based on iteration - much stronger for 20D
        if iteration < self.n_iter * 0.3:
            # Very strong repulsion in early iterations, never below 3.0
            repulsion_factor = self.repulsion_factor * (1.0 - 0.4 * iteration / (self.n_iter * 0.3))
            repulsion_factor = max(3.0, repulsion_factor)
        elif iteration < self.n_iter * 0.6:
            # Moderate repulsion in middle iterations
            repulsion_factor = 3.0
        else:
            # Lower repulsion in final iterations for refinement
            repulsion_factor = 2.5

        repulsive *= repulsion_factor

        # Dynamic mode balancing: boost forces on particles of sparse modes
        if self.mode_assignments is not None:
            unique_modes = np.unique(self.mode_assignments)
            mode_counts = np.bincount(self.mode_assignments, minlength=len(unique_modes))

            if len(mode_counts) > 0 and np.any(mode_counts > 0):
                ideal_count = n_particles / len(unique_modes)

                # Aggressive balancing power, capped to avoid extreme values
                mode_weights = (ideal_count / np.maximum(mode_counts, 1)) ** 2.5
                mode_weights = np.clip(mode_weights, 0.4, 10.0)

                particle_weights = mode_weights[self.mode_assignments]
                attractive *= particle_weights[:, np.newaxis]
                repulsive *= particle_weights[:, np.newaxis]

        # Combine forces
        update = attractive + repulsive

        # Limit the overall update norm relative to the typical particle spacing
        update_norm = np.linalg.norm(update)
        if update_norm > 1e-10:
            if n_particles > 1:
                # Estimate the average pairwise distance on a small random subset
                sample_size = min(n_particles, 50)
                sampled = particles[np.random.choice(n_particles, sample_size, replace=False)]
                first, second = np.triu_indices(sample_size, k=1)
                avg_particle_dist = np.mean(np.linalg.norm(sampled[first] - sampled[second], axis=1))
            else:
                avg_particle_dist = 1.0

            # Allow 8% movement relative to the average distance
            scale_factor = avg_particle_dist * 0.08

            if update_norm > scale_factor:
                update = update * (scale_factor / update_norm)

        return update
    
    def _direct_mode_intervention(self, particles, iteration):
        """
        # & Directly intervene to maintain mode coverage in difficult 20D
        # &
        # & Particles are taken from the most populated mode and re-drawn around the
        # & center of an underrepresented mode. The input array is modified in place
        # & and self.mode_assignments is updated to match.
        # &
        # & Args:
        # &     particles: Array of shape (n_particles, dim); modified in place
        # &     iteration: Current iteration index
        # &
        # & Returns:
        # &     The same particles array, after any relocation
        """
        if self._mode_centers is None:
            return particles

        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]

        # Only intervene during the first 70% of the run
        if iteration > self.n_iter * 0.7:
            return particles

        # Update mode assignments first
        self._update_mode_assignments(particles)
        if self.mode_assignments is None:
            # Nothing to count against; previously this crashed inside np.bincount
            return particles

        # Count particles per mode
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)
        target_per_mode = n_particles / n_modes

        def cholesky_factor(mode_idx):
            # Regularized Cholesky factor, remembered for later calls
            factor = np.linalg.cholesky(self._mode_covs[mode_idx] + 1e-5 * np.eye(dim))
            self._cholesky_cache[mode_idx] = factor
            return factor

        def eigen_factor(mode_idx):
            # Square-root factor from an eigendecomposition with floored eigenvalues
            eigvals, eigvecs = np.linalg.eigh(self._mode_covs[mode_idx] + 1e-5 * np.eye(dim))
            eigvals = np.maximum(eigvals, 1e-6)
            return eigvecs @ np.diag(np.sqrt(eigvals))

        def relocate(indices, mode_idx, build_factor):
            # Re-draw particles[indices] around the center of mode_idx, using
            # correlated noise when a covariance factor can be obtained and
            # isotropic noise otherwise.
            center = self._mode_centers[mode_idx]
            drawn = None
            if self._mode_covs is not None and mode_idx < len(self._mode_covs):
                try:
                    if mode_idx in self._cholesky_cache:
                        factor = self._cholesky_cache[mode_idx]
                    else:
                        factor = build_factor(mode_idx)
                    drawn = center + np.random.randn(len(indices), dim) @ factor.T * 0.4
                except Exception:
                    # Best effort: fall back to isotropic noise below
                    drawn = None
            if drawn is None:
                drawn = center + np.random.randn(len(indices), dim) * 0.4
            particles[indices] = drawn
            self.mode_assignments[indices] = mode_idx

        # Severely underrepresented modes: fewer than 7% of the uniform share
        critically_low = np.where(mode_counts < target_per_mode * 0.07)[0]

        if len(critically_low) > 0:
            # Up to 40% of the expected count, but never more than 10% of all particles
            particles_to_move = min(int(target_per_mode * 0.4),
                                    int(n_particles * 0.1))

            for mode_idx in critically_low:
                # Take particles from the most populated mode
                most_pop_mode = np.argmax(mode_counts)
                if most_pop_mode == mode_idx:
                    # Find next most populated mode
                    temp_counts = mode_counts.copy()
                    temp_counts[most_pop_mode] = 0
                    if np.sum(temp_counts) == 0:
                        continue  # No other modes have particles
                    most_pop_mode = np.argmax(temp_counts)

                source_indices = np.where(self.mode_assignments == most_pop_mode)[0]

                # Leave the donor with at least 40% of the uniform share
                n_move = min(particles_to_move, len(source_indices),
                             int(mode_counts[most_pop_mode] - target_per_mode * 0.4))

                if n_move > 0:
                    relocate(source_indices[:n_move], mode_idx, cholesky_factor)
                    mode_counts[most_pop_mode] -= n_move
                    mode_counts[mode_idx] += n_move

        # Moderately underrepresented modes: between 7% and 30% of the uniform share
        moderately_low = np.where((mode_counts >= target_per_mode * 0.07) &
                                  (mode_counts < target_per_mode * 0.3))[0]

        if len(moderately_low) > 0 and iteration % 3 == 0:  # Only every 3 iterations
            particles_to_move = int(target_per_mode * 0.2)  # 20% of expected

            for mode_idx in moderately_low:
                most_pop_mode = np.argmax(mode_counts)
                if most_pop_mode == mode_idx or mode_counts[most_pop_mode] < target_per_mode * 1.2:
                    continue  # Skip if not enough excess particles

                source_indices = np.where(self.mode_assignments == most_pop_mode)[0]

                # Only the donor's surplus above the uniform share may be moved
                n_move = min(particles_to_move, len(source_indices),
                             int(mode_counts[most_pop_mode] - target_per_mode))

                if n_move > 0:
                    relocate(source_indices[:n_move], mode_idx, eigen_factor)
                    mode_counts[most_pop_mode] -= n_move
                    mode_counts[mode_idx] += n_move

        return particles
    
    def _mode_balanced_resample(self, particles):
        """
        # & Enhanced mode-aware resampling for 20D
        # &
        # & Modes holding fewer than 40% of their uniform share are topped up with
        # & particles taken from overrepresented modes (more than 110% of the share).
        # & A donor only gives up its surplus above the uniform share, and further
        # & donors are used in turn until the deficit is covered, so that filling one
        # & mode can no longer starve another.
        # &
        # & Args:
        # &     particles: Array of shape (n_particles, dim); not modified
        # &
        # & Returns:
        # &     A new array containing the re-drawn particles; self.mode_assignments
        # &     is updated in place to match. The input itself is returned when mode
        # &     centers or assignments are unavailable.
        """
        if self._mode_centers is None or self.mode_assignments is None:
            return particles

        n_particles = len(particles)
        n_modes = len(self._mode_centers)
        dim = particles.shape[1]
        new_particles = particles.copy()

        # Count particles per mode
        mode_counts = np.bincount(self.mode_assignments, minlength=n_modes)

        # Target count per mode (uniform distribution)
        target_count = n_particles / n_modes

        def sampling_factor(mode_idx):
            # Matrix F with F @ F.T approximating the mode covariance, or None
            # when only isotropic noise is possible.
            if self._mode_covs is None or mode_idx >= len(self._mode_covs):
                return None
            cov = self._mode_covs[mode_idx]
            try:
                # Use cached Cholesky if available, otherwise compute and cache it
                if mode_idx in self._cholesky_cache:
                    factor = self._cholesky_cache[mode_idx]
                else:
                    factor = np.linalg.cholesky(cov + 1e-5 * np.eye(dim))
                    self._cholesky_cache[mode_idx] = factor
            except Exception:
                # If Cholesky fails, use an eigendecomposition with floored eigenvalues
                try:
                    eigvals, eigvecs = np.linalg.eigh(cov)
                    eigvals = np.maximum(eigvals, 1e-6)
                    factor = eigvecs @ np.diag(np.sqrt(eigvals))
                except Exception:
                    return None
            # A factor of the wrong size cannot be applied; use isotropic noise
            if np.shape(factor) != (dim, dim):
                return None
            return factor

        for mode_idx in range(n_modes):
            # Only modes with significantly too few particles are topped up
            if mode_counts[mode_idx] >= target_count * 0.4:
                continue

            mode_center = self._mode_centers[mode_idx]
            factor = sampling_factor(mode_idx)
            deficit = int(target_count - mode_counts[mode_idx])

            while deficit > 0:
                # Most overrepresented mode acts as donor
                donors = np.where(mode_counts > target_count * 1.1)[0]
                if len(donors) == 0:
                    break
                donor = donors[np.argmax(mode_counts[donors])]
                donor_indices = np.where(self.mode_assignments == donor)[0]

                # Never push the donor below the uniform share
                surplus = int(mode_counts[donor] - target_count)
                n_replace = min(deficit, surplus, len(donor_indices))
                if n_replace <= 0:
                    break

                chosen = donor_indices[:n_replace]
                noise = np.random.randn(n_replace, dim)
                if factor is not None:
                    # Correlated noise following the mode covariance
                    noise = noise @ factor.T
                new_particles[chosen] = mode_center + noise * 0.4
                self.mode_assignments[chosen] = mode_idx

                # Update mode counts
                mode_counts[donor] -= n_replace
                mode_counts[mode_idx] += n_replace
                deficit -= n_replace

        return new_particles
    
    def update(self, particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Run enhanced ESCORT optimization for 20D
        # & With parameters specially tuned for very high-dimensional spaces
        # &
        # & Args:
        # &     particles: Initial particles, shape (n_particles, dim); copied, not modified
        # &     score_fn: Score function forwarded to _compute_svgd_update
        # &     target_samples: Optional target samples; enables the GSWD term when lambda_reg > 0
        # &     return_convergence: If True, also return a dict with the convergence histories
        # &
        # & Returns:
        # &     Final particles, or (particles, convergence_info) when return_convergence is True
        """
        # Initialize
        particles = particles.copy()
        n_particles, dim = particles.shape

        # Better initialization with mode coverage
        particles = self._initialize_particles(particles)

        # Tracking variables
        delta_norm_history = []
        step_size_history = []
        curr_step_size = self.step_size
        current_noise = self.noise_level
        # Defined up front so the final report also works when n_iter is 0
        delta_norm = 0.0

        # Prepare GSWD if target samples are provided
        use_gswd = target_samples is not None and self.lambda_reg > 0
        if use_gswd:
            self.gswd.fit(target_samples, particles)

        # Setup progress bar (tqdm is imported at module level)
        iterator = range(self.n_iter)
        if self.verbose:
            iterator = tqdm(iterator, desc="ESCORT 20D")

        # Main update loop
        for t in iterator:
            # Aggressive early exploration
            if t < self.n_iter * 0.3 and t % self.intervention_freq == 0:
                # Directly enforce mode coverage periodically
                particles = self._direct_mode_intervention(particles, t)

            # Compute SVGD update with better correlation awareness
            svgd_update = self._compute_svgd_update(particles, score_fn, t)

            # Add GSWD regularization - apply every 3rd iteration for efficiency in 20D
            update = svgd_update
            if use_gswd and (t % 3 == 0):
                try:
                    gswd_reg = self.gswd.get_regularizer(target_samples, particles, self.lambda_reg)
                    update = svgd_update + gswd_reg
                except Exception:
                    # Regularizer is optional: fall back to the plain SVGD update
                    update = svgd_update

            # Add noise with a decaying schedule for 20D
            if current_noise > 0:
                if t < self.n_iter * 0.3:
                    # Strong noise early on - every iteration
                    noise = np.random.randn(*particles.shape) * current_noise * 2.2
                    update = update + noise
                elif t % 2 == 0 and t < self.n_iter * 0.6:
                    # Moderate noise in middle phase
                    noise = np.random.randn(*particles.shape) * current_noise * 1.2
                    update = update + noise
                elif t % 4 == 0 and t < self.n_iter * 0.8:
                    # Light noise in later phase
                    noise = np.random.randn(*particles.shape) * current_noise * 0.6
                    update = update + noise

                if t < self.n_iter * 0.3:
                    # Very slow decay in early iterations
                    current_noise *= self.noise_decay ** 0.3
                else:
                    # Normal decay later
                    current_noise *= self.noise_decay

            # Apply update
            new_particles = particles + curr_step_size * update

            # Mode-based resampling
            if t > 0 and t % self.mode_balance_freq == 0:
                # First update mode assignments, then rebalance
                self._update_mode_assignments(new_particles)
                new_particles = self._mode_balanced_resample(new_particles)

            # Check convergence
            delta = new_particles - particles
            delta_norm = np.linalg.norm(delta) / n_particles
            delta_norm_history.append(delta_norm)
            step_size_history.append(curr_step_size)

            # Update particles
            particles = new_particles

            # Step size decay - gentler during the first 40% of the run
            if self.decay_step_size:
                if t < self.n_iter * 0.4:
                    curr_step_size = self.step_size / (1.0 + 0.002 * t)
                else:
                    curr_step_size = self.step_size / (1.0 + 0.007 * t)

            # Early stopping, checked every 10th iteration in the last 30% of the run
            if t > self.n_iter * 0.7 and t % 10 == 0 and delta_norm < self.tol:
                if self.verbose:
                    print(f"Converged after {t+1} iterations. Delta norm: {delta_norm:.6f}")
                self.iterations_run = t + 1
                break

        # Loop finished without an early stop
        else:
            self.iterations_run = self.n_iter
            if self.verbose:
                print(f"Maximum iterations reached. Final delta norm: {delta_norm:.6f}")

        if return_convergence:
            convergence_info = {
                'delta_norm_history': np.array(delta_norm_history),
                'step_size_history': np.array(step_size_history),
                'iterations_run': self.iterations_run
            }
            return particles, convergence_info

        return particles
    
    def fit_transform(self, initial_particles, score_fn, target_samples=None,
                      return_convergence=False, reset=True):
        """
        # & Optimize the given initial particles, optionally from a clean state
        # &
        # & With reset=True the detected modes, mode assignments and Cholesky cache
        # & are cleared and the mode centers/covariances are re-read from target_info
        # & before the optimization is delegated to update().
        """
        if not reset:
            # Keep whatever state a previous run left behind
            return self.update(initial_particles, score_fn, target_samples, return_convergence)

        # Forget everything learned during a previous run
        self.detected_modes = None
        self.mode_assignments = None
        self._cholesky_cache = {}

        # Mode information comes from target_info when it is provided
        info = self.target_info
        if info is None:
            self._mode_centers = None
            self._mode_covs = None
        else:
            self._mode_centers = info.get('centers', None)
            self._mode_covs = info.get('covs', None)

        return self.update(initial_particles, score_fn, target_samples, return_convergence)


class ESCORT20DAdapter:
    """
    # & Thin wrapper exposing ESCORT20D through the fit_transform interface
    # & shared with the other methods
    """
    def __init__(self, n_iter=300, step_size=0.02, verbose=True, target_info=None):
        # Caller-controlled options plus the fixed 20D tuning values
        escort_settings = dict(
            step_size=step_size,
            n_iter=n_iter,
            verbose=verbose,
            noise_level=0.3,   # Higher noise for 20D
            noise_decay=0.97,  # Slower decay for 20D
            lambda_reg=0.5,    # Higher regularization for 20D
            target_info=target_info,
        )
        self.escort = ESCORT20D(**escort_settings)

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        # Delegate unchanged to the wrapped optimizer
        optimizer = self.escort
        return optimizer.fit_transform(
            initial_particles, score_fn, target_samples, return_convergence)


# ========================================
# Visualization Functions for 20D
# ========================================

def visualize_results_with_error_bars(mean_results_df, all_results_df, all_particles, 
                                    all_convergence, target_gmm):
    """
    # & Visualize multi-seed experiment results for the 20D evaluation
    # &
    # & Writes five PNG figures into <SCRIPT_DIR>/plots_20d_multiseed:
    # & metric bar charts with error bars, per-run box plots, a PCA projection
    # & of every run, performance grouped by initialization type, and a
    # & summary table. Each figure is produced independently: a failure in one
    # & is printed with its traceback and the remaining figures are still
    # & attempted.
    # &
    # & Args:
    # &     mean_results_df: DataFrame indexed by method name. The columns
    # &         "<metric>_mean" and "<metric>_se" are read for every plotted
    # &         metric and for "Runtime (s)". The plain columns
    # &         "Mode Coverage", "Correlation Error", "ESS", "MMD" and
    # &         "Runtime (s)" are inserted verbatim into plot titles and the
    # &         summary table (presumably pre-formatted "mean ± SE" text -
    # &         TODO confirm against the code that builds this frame)
    # &     all_results_df: DataFrame with one row per run; the columns
    # &         'Method', 'Initialization', 'Runtime (s)' and one column per
    # &         metric are read
    # &     all_particles: Mapping from method name to a sequence of particle
    # &         arrays (one per run); each array is passed to PCA.transform
    # &     all_convergence: Accepted for interface compatibility; not used
    # &     target_gmm: Target distribution; only its sample(n) method is used
    # &
    # & Returns:
    # &     None
    """
    # Create directory for plots
    plots_dir = os.path.join(SCRIPT_DIR, "plots_20d_multiseed")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Get methods to visualize
    methods = list(mean_results_df.index)
    n_methods = len(methods)
    
    # Shared plot configuration. Defined outside the try blocks so that a
    # failure in one section cannot leave a later section without these names.
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)', 
              'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein']
    higher_is_better = ['Mode Coverage', 'ESS']
    method_palette = ['blue', 'green', 'red', 'purple', 'orange', 'brown']
    method_colors = [method_palette[k % len(method_palette)] for k in range(n_methods)]
    run_colors = ['red', 'green', 'purple', 'orange', 'brown', 'pink', 'gray', 'olive', 'cyan']
    box_style = dict(patch_artist=True,
                     boxprops=dict(facecolor='lightblue', color='blue'),
                     whiskerprops=dict(color='blue'),
                     capprops=dict(color='blue'),
                     medianprops=dict(color='red'))
    
    # Generate target samples for visualization
    # NOTE: this reseeds NumPy's global RNG as a side effect
    np.random.seed(42)  # Fixed seed for consistent visualization
    n_viz_samples = 2000
    target_samples = target_gmm.sample(n_viz_samples)
    
    # 1. Metrics comparison with error bars
    fig = None
    try:
        print("Generating metrics comparison with error bars...")
        fig = plt.figure(figsize=(18, 15))
        
        # Create bar plots for each metric with error bars
        for i, metric in enumerate(metrics):
            plt.subplot(3, 3, i+1)
            
            # Extract means and standard errors
            means = [mean_results_df.loc[method, f"{metric}_mean"] for method in methods]
            errors = [mean_results_df.loc[method, f"{metric}_se"] for method in methods]
            
            # For KL metrics, cap large values for better visualization.
            # `means` is already capped when the errors are filtered, so
            # `m < 20.0` is False exactly for the capped bars and their
            # error bars are dropped.
            if 'KL' in metric:
                means = [min(m, 20.0) for m in means]
                errors = [e if m < 20.0 else 0 for m, e in zip(means, errors)]
            
            # Create bar plot with error bars
            bars = plt.bar(methods, means, color=method_colors,
                          yerr=errors, capsize=10, alpha=0.7)
            
            # Add value annotations just above each error bar
            for j, bar in enumerate(bars):
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2., height + errors[j] + 0.01,
                        f'{means[j]:.4f} ± {errors[j]:.4f}', ha='center', va='bottom', rotation=45,
                        fontsize=9)
            
            plt.xticks(rotation=45)
            plt.grid(axis='y', alpha=0.3)
            
            # For Mode Coverage and ESS, higher is better
            if metric in higher_is_better:
                plt.ylim(0, 1.1)
                plt.title(f"{metric} (higher is better)")
            else:
                plt.title(f"{metric} (lower is better)")
        
        # Add runtime comparison in the slot after the seven metric panels
        plt.subplot(3, 3, 8)
        runtime_means = [mean_results_df.loc[method, "Runtime (s)_mean"] for method in methods]
        runtime_errors = [mean_results_df.loc[method, "Runtime (s)_se"] for method in methods]
        
        bars = plt.bar(methods, runtime_means, color=method_colors,
                      yerr=runtime_errors, capsize=10, alpha=0.7)
        
        for j, bar in enumerate(bars):
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + runtime_errors[j] + 0.01,
                    f'{runtime_means[j]:.2f}s', ha='center', va='bottom', rotation=0,
                    fontsize=9)
        
        plt.title("Runtime (seconds)")
        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "metrics_comparison_with_errors_20d.png"), dpi=300)
    except Exception as e:
        print(f"Error generating metrics comparison: {e}")
        traceback.print_exc()
    finally:
        # Close on failure as well, so a broken section does not leak its figure
        if fig is not None:
            plt.close(fig)
    
    # 2. Box plots showing distribution of results across runs
    fig = None
    try:
        print("Generating boxplots for metric distributions across runs...")
        fig = plt.figure(figsize=(18, 15))
        
        for i, metric in enumerate(metrics):
            plt.subplot(3, 3, i+1)
            
            # One box per method, built from that method's per-run values
            box_data = [all_results_df[all_results_df['Method'] == method][metric].values 
                       for method in methods]
            
            plt.boxplot(box_data, labels=methods, **box_style)
            
            plt.title(f"{metric} Distribution Across Different Initializations")
            plt.xticks(rotation=45)
            plt.grid(axis='y', alpha=0.3)
            
            # For Mode Coverage and ESS, higher is better
            if metric in higher_is_better:
                plt.ylim(0, 1.1)
        
        # Add runtime boxplot
        plt.subplot(3, 3, 8)
        runtime_data = [all_results_df[all_results_df['Method'] == method]['Runtime (s)'].values 
                       for method in methods]
        
        plt.boxplot(runtime_data, labels=methods, **box_style)
        
        plt.title("Runtime (s) Distribution Across Different Initializations")
        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "metrics_boxplots_20d.png"), dpi=300)
    except Exception as e:
        print(f"Error generating boxplots: {e}")
        traceback.print_exc()
    finally:
        if fig is not None:
            plt.close(fig)
    
    # 3. PCA visualization with multiple runs
    fig = None
    try:
        print("Generating PCA visualizations of multiple runs...")
        # Fit PCA on the target samples
        pca = PCA(n_components=2)
        pca.fit(target_samples)
        
        # Project target samples
        target_proj = pca.transform(target_samples)
        
        # Two columns and at least two rows: identical to the previous fixed
        # 2x2 layout for up to four methods, and grows for more methods
        # instead of failing inside plt.subplot.
        n_cols = 2
        n_rows = max(2, (n_methods + n_cols - 1) // n_cols)
        fig = plt.figure(figsize=(18, 7.5 * n_rows))
        
        # Subplot for each method
        for i, method_name in enumerate(methods):
            plt.subplot(n_rows, n_cols, i+1)
            
            # Plot target as heatmap
            plt.hist2d(target_proj[:, 0], target_proj[:, 1], bins=50, cmap='Blues', alpha=0.3)
            
            # Plot each run for this method with different transparency
            particles_runs = all_particles[method_name]
            
            for j, particles in enumerate(particles_runs):
                # Project particles to 2D using same PCA
                particles_proj = pca.transform(particles)
                
                # Plot with run-specific color and transparency
                alpha = 0.8 if j == 0 else 0.3  # Make first run more visible
                plt.scatter(particles_proj[:, 0], particles_proj[:, 1], 
                          c=run_colors[j % len(run_colors)], alpha=alpha, s=10, 
                          label=f"Run {j+1}" if j < 3 else None)  # Only show first 3 in legend
            
            # Empty scatter that only contributes a legend entry for the rest
            if len(particles_runs) > 3:
                plt.scatter([], [], c='black', alpha=0.3, s=10, label=f"+ {len(particles_runs)-3} more runs")
                
            # Add title with metrics (values inserted exactly as stored)
            mode_coverage = mean_results_df.loc[method_name, 'Mode Coverage']
            corr_error = mean_results_df.loc[method_name, 'Correlation Error']
            
            plt.title(f"{method_name} - Multiple Runs\n"
                    f"Mode Coverage: {mode_coverage}, Corr Error: {corr_error}")
            
            plt.xlabel("PC1")
            plt.ylabel("PC2")
            plt.legend(loc='upper right')
            plt.grid(alpha=0.2)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "pca_multiple_runs_20d.png"), dpi=300)
    except Exception as e:
        print(f"Error generating PCA visualizations: {e}")
        traceback.print_exc()
    finally:
        if fig is not None:
            plt.close(fig)
    
    # 4. Performance by initialization type
    fig = None
    try:
        initializations = sorted(all_results_df['Initialization'].unique())
        # Only meaningful when there is more than one initialization to compare
        if len(initializations) > 1:
            print("Generating performance by initialization type...")
            fig = plt.figure(figsize=(18, 15))
            
            # 0.15 as before for up to six methods; narrower beyond that so a
            # group of bars never spills into the neighbouring group.
            bar_width = min(0.15, 0.9 / max(n_methods, 1))
            group_positions = np.arange(len(initializations))
            
            for i, metric in enumerate(['MMD', 'Mode Coverage', 'Correlation Error']):
                plt.subplot(3, 1, i+1)
                
                # Grouped bar chart: one bar per method inside each initialization group
                for j, method in enumerate(methods):
                    values = []
                    errors = []
                    for init in initializations:
                        df_subset = all_results_df[(all_results_df['Method'] == method) & 
                                                 (all_results_df['Initialization'] == init)]
                        if not df_subset.empty:
                            values.append(df_subset[metric].mean())
                            errors.append(scipy.stats.sem(df_subset[metric]))
                        else:
                            # No run for this method/initialization pair
                            values.append(0)
                            errors.append(0)
                    
                    plt.bar(group_positions + bar_width * j, values, bar_width, 
                           label=method, 
                           color=method_colors[j],
                           yerr=errors, capsize=5)
                
                # Add labels and legend; ticks sit at the centre of each group
                plt.xlabel('Initialization Type')
                plt.ylabel(metric)
                plt.title(f'Performance by Initialization Type - {metric}')
                plt.xticks(group_positions + bar_width * (n_methods - 1) / 2, 
                          initializations, rotation=45)
                plt.legend()
                plt.grid(axis='y', alpha=0.3)
                
                # For Mode Coverage, higher is better
                if metric == 'Mode Coverage':
                    plt.ylim(0, 1.1)
                    
            plt.tight_layout()
            plt.savefig(os.path.join(plots_dir, "performance_by_initialization_20d.png"), dpi=300)
    except Exception as e:
        print(f"Error generating initialization comparison: {e}")
        traceback.print_exc()
    finally:
        if fig is not None:
            plt.close(fig)
    
    # 5. Create a summary table
    fig = None
    try:
        print("Generating summary table...")
        fig = plt.figure(figsize=(15, 8))
        plt.axis('off')
        
        # Header row followed by one row per method
        table_headers = ['Method', 'Mode Coverage', 'Corr. Error', 'ESS', 'MMD', 'Runtime (s)']
        summary_columns = ['Mode Coverage', 'Correlation Error', 'ESS', 'MMD', 'Runtime (s)']
        table_data = [table_headers]
        for method in methods:
            table_data.append(
                [method] + [mean_results_df.loc[method, column] for column in summary_columns])
        
        table = plt.table(cellText=table_data, loc='center', cellLoc='center', 
                         colWidths=[0.25, 0.15, 0.15, 0.15, 0.15, 0.15])
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.8)
        plt.title("Summary Metrics - 20D Evaluation (Mean ± SE)", fontsize=16, pad=20)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "summary_table_20d.png"), dpi=300)
    except Exception as e:
        print(f"Error generating summary table: {e}")
        traceback.print_exc()
    finally:
        if fig is not None:
            plt.close(fig)
    
    print(f"All visualizations saved to {plots_dir}")


# ========================================
# Experiment Functions
# ========================================

# def run_experiment(methods_to_run=None, n_iter=300, step_size=0.01, verbose=True):
#     """
#     # & Run experiment comparing different methods on 20D distribution
#     # &
#     # & Args:
#     # &     methods_to_run (list): Methods to evaluate
#     # &     n_iter (int): Number of iterations
#     # &     step_size (float): Step size for updates
#     # &     verbose (bool): Whether to display progress
#     # &
#     # & Returns:
#     # &     tuple: (results_df, target_distribution, particles_dict, convergence_dict)
#     """
#     print("Starting 20D GMM evaluation experiment...")
    
#     if methods_to_run is None:
#         methods_to_run = ['ESCORT20D', 'SVGD', 'DVRL', 'SIR']
    
#     # Create target distribution
#     target_gmm = HighlyCorrelated20DGMMDistribution()
    
#     # Generate target samples
#     n_particles = 1000
#     target_samples = target_gmm.sample(n_particles)
    
#     # Create initial particles (random for fair comparison)
#     np.random.seed(42)  # For reproducibility
#     initial_particles = np.random.randn(n_particles, 20) * 3
    
#     # Score function for target distribution
#     score_fn = target_gmm.score
    
#     # Create methods to evaluate
#     methods = {}
#     particles_dict = {}
#     convergence_dict = {}
#     results_dict = {}
    
#     # Create target info for improved methods
#     target_info = {
#         'n_modes': len(target_gmm.means),
#         'centers': target_gmm.means,
#         'covs': [cov for cov in target_gmm.covs]
#     }
    
#     # Add methods based on what's requested
#     if 'ESCORT20D' in methods_to_run:
#         methods['ESCORT20D'] = ESCORT20DAdapter(
#             n_iter=n_iter, step_size=step_size, verbose=verbose, target_info=target_info)
    
#     if 'SVGD' in methods_to_run:
#         methods['SVGD'] = StableSVGD20DAdapter(
#             n_iter=n_iter, 
#             step_size=step_size, 
#             verbose=verbose,
#             target_info=target_info
#         )
    
#     if 'DVRL' in methods_to_run:
#         try:
#             # Initialize the DVRL model
#             dvrl = DVRL(
#                 obs_dim=20,         # 20D state space
#                 action_dim=1,       # Simple 1D actions for testing
#                 h_dim=128,          # Larger hidden state dimension for 20D
#                 z_dim=20,           # Latent state dimension (matches state dimension)
#                 n_particles=100,    # Use fewer particles for stability
#                 continuous_actions=True
#             )
            
#             # Explicitly move model to CPU 
#             dvrl = dvrl.to(torch.device('cpu'))
            
#             # Create the adapter with the fixed implementation
#             methods['DVRL'] = DVRLAdapter20D(dvrl, n_samples=n_particles)
#         except Exception as e:
#             print(f"Error initializing DVRL: {e}")
#             # Create a fallback that returns initial particles
#             methods['DVRL'] = lambda initial_particles, score_fn, target_samples=None, return_convergence=False: (
#                 (initial_particles.copy(), {"iterations": 0}) if return_convergence else initial_particles.copy()
#             )
    
#     if 'SIR' in methods_to_run:
#         methods['SIR'] = SIRAdapter(n_iter=1)  # Just one iteration for SIR
    
#     # Run each method
#     for method_name, method in methods.items():
#         if method is None:
#             continue
            
#         print(f"Running {method_name}...")
        
#         try:
#             start_time = time.time()
            
#             # Special handling for lambda fallback if used
#             if callable(method) and not hasattr(method, 'fit_transform'):
#                 # This is our lambda fallback for DVRL
#                 particles, convergence = method(
#                     initial_particles.copy(), score_fn, target_samples, return_convergence=True)
#             else:
#                 # Normal method call
#                 particles, convergence = method.fit_transform(
#                     initial_particles.copy(), score_fn, target_samples, return_convergence=True)
                
#             end_time = time.time()
            
#             # Store results
#             particles_dict[method_name] = particles
#             convergence_dict[method_name] = convergence
            
#             # Evaluate the method
#             evaluation = evaluate_method_20d(
#                 method_name, particles, target_gmm, target_samples, 
#                 runtime=end_time - start_time)
            
#             # Store evaluation results
#             results_dict[method_name] = evaluation
            
#             print(f"{method_name} completed in {end_time - start_time:.2f} seconds")
#         except Exception as e:
#             print(f"Error in {method_name}: {e}")
#             traceback.print_exc()
            
#             # Create fallback results for this method
#             particles = initial_particles.copy() + np.random.randn(*initial_particles.shape) * 0.1
#             particles_dict[method_name] = particles
#             convergence_dict[method_name] = {"iterations": 0}
            
#             # Still evaluate with the fallback particles
#             evaluation = evaluate_method_20d(
#                 method_name, particles, target_gmm, target_samples, 
#                 runtime=0.0)  # Use 0 runtime since this is a fallback
            
#             # Store evaluation results
#             results_dict[method_name] = evaluation
    
#     # Create results DataFrame
#     results_df = pd.DataFrame.from_dict(results_dict, orient='index')
    
#     # Display results
#     print("\nResults Summary:")
#     print(results_df)
    
#     # Generate visualizations
#     try:
#         visualize_results_20d(target_gmm, target_samples, particles_dict, 
#                             convergence_dict, results_df)
#     except Exception as e:
#         print(f"Error in visualization: {e}")
#         traceback.print_exc()
    
#     return results_df, target_gmm, particles_dict, convergence_dict


# ========================================
# Main Execution
# ========================================

def build_run_seeds(n_runs, fixed_seeds=False):
    """
    # & Build the list of per-run random seeds
    # &
    # & Args:
    # &     n_runs (int): Number of runs; must be at least 1
    # &     fixed_seeds (bool): When True, return a reproducible list that
    # &         starts with 42, 123, 456, 789, 101 and, beyond five runs,
    # &         continues with the same values offset by multiples of 1000
    # &         (1042, 1123, ...), so the list always has n_runs distinct
    # &         entries. When False, draw a master seed from NumPy's global
    # &         RNG, print it, reseed the global RNG with it and draw n_runs
    # &         seeds in [0, 10000).
    # &
    # & Returns:
    # &     list[int]: Exactly n_runs seeds
    # &
    # & Raises:
    # &     ValueError: If n_runs is smaller than 1
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    if fixed_seeds:
        base_seeds = [42, 123, 456, 789, 101]  # Fixed seeds for reproducibility
        n_base = len(base_seeds)
        # The first five entries are the base seeds unchanged; later entries
        # reuse them with a +1000 offset per wrap-around. All base seeds are
        # distinct and below 1000, so every generated seed is unique.
        return [base_seeds[k % n_base] + 1000 * (k // n_base) for k in range(n_runs)]

    # Generate random master seed
    master_seed = np.random.randint(0, 10000)
    print(f"Master seed: {master_seed}")

    # Use master seed to generate seeds for individual runs
    np.random.seed(master_seed)
    # Plain Python ints, so both branches return the same type
    return np.random.randint(0, 10000, size=n_runs).tolist()


if __name__ == "__main__":
    import argparse
    
    # Set up argument parser
    parser = argparse.ArgumentParser(description='ESCORT 20D Framework Evaluation with Multiple Seeds')
    parser.add_argument('--methods', nargs='+', 
                    default=['ESCORT20D', 'SVGD', 'DVRL', 'SIR'],
                    help='Methods to evaluate (default: ESCORT20D SVGD DVRL SIR)')
    parser.add_argument('--n_runs', type=int, default=5,
                    help='Number of runs with different initializations (default: 5)')
    parser.add_argument('--n_iter', type=int, default=300, 
                    help='Number of iterations (default: 300)')
    parser.add_argument('--step_size', type=float, default=0.01,
                    help='Step size for updates (default: 0.01)')
    # store_false with dest='verbose': verbose is True unless the flag is given
    parser.add_argument('--no_verbose', action='store_false', dest='verbose',
                    help='Disable verbose output (default: verbose enabled)')
    parser.add_argument('--fixed_seeds', action='store_true',
                    help='Use fixed seeds instead of random ones (default: False)')
    
    # Parse arguments
    args = parser.parse_args()
    
    # Reject run counts that cannot produce any result
    if args.n_runs < 1:
        parser.error("--n_runs must be at least 1")
    
    # Configure parameters
    method_params = {
        'n_iter': args.n_iter,
        'step_size': args.step_size,
        'verbose': args.verbose,
    }
    
    # One seed per run, fixed or randomly drawn depending on --fixed_seeds
    seeds = build_run_seeds(args.n_runs, fixed_seeds=args.fixed_seeds)
    
    # Print seeds being used
    print(f"Using seeds: {seeds}")
    
    # Run the experiment with multiple seeds
    mean_results_df, all_results_df, all_particles, all_convergence, target_gmm = run_experiment_with_multiple_seeds(
        methods_to_run=args.methods,
        n_runs=args.n_runs,
        seeds=seeds,
        **method_params
    )
    
    # Save results to CSV
    results_dir = os.path.join(SCRIPT_DIR, "results_20d_multiseed")
    os.makedirs(results_dir, exist_ok=True)
    mean_results_df.to_csv(os.path.join(results_dir, "escort_20d_mean_results.csv"))
    all_results_df.to_csv(os.path.join(results_dir, "escort_20d_all_results.csv"))
    
    # Generate visualizations
    visualize_results_with_error_bars(
        mean_results_df, all_results_df, all_particles, all_convergence, target_gmm
    )
    
    print("\nExperiment complete. Results saved to CSV and visualizations saved to plots_20d_multiseed directory.")
    print(f"Results saved in: {results_dir}")
