"""
    ESCORT Framework Evaluation on 2D Multi-Modal Correlated GMM
    
    This script tests the ESCORT framework against a challenging 2D GMM distribution
    with multiple modes having different correlation structures. The experiment is designed 
    to evaluate how well different methods handle the key challenges of POMDPs:
    1. High dimensionality 
    2. Multi-modality
    3. Dimensional correlation
    
    Fixed version with improved numerical stability and error handling.
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde, wasserstein_distance
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.preprocessing import StandardScaler
import pandas as pd
import time
import os
import argparse
import torch
import torch.optim as optim
from scipy.linalg import sqrtm
import traceback

from dvrl.dvrl import DVRL
from belief_assessment.distributions import GMMDistribution
from escort.svgd import SVGD, AdaptiveSVGD
from belief_assessment.evaluation.visualize_distributions import GMMVisualizer

# Seed NumPy's global random number generator so that np.random calls made
# after this point are reproducible across runs. Only NumPy is seeded here.
np.random.seed(42)

# Absolute path of the directory containing this script; intended as the base
# location for saving outputs.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


class DVRLAdapter:
    """Adapter exposing a DVRL model through the ``fit_transform`` interface.

    The adapter initialises the model's belief state, rolls it forward a few
    steps with all-zero actions and observations, and returns the latent
    ``z`` particles as the particle approximation.

    Args:
        dvrl_model: Model object providing ``to``, ``init_belief``,
            ``update_belief``, ``action_dim``, ``obs_dim`` and ``n_particles``.
        n_samples (int): Number of particles returned by ``fit_transform``.
    """

    def __init__(self, dvrl_model, n_samples=1000):
        self.dvrl_model = dvrl_model
        self.n_samples = n_samples

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """Generate ``n_samples`` particles from the DVRL belief state.

        Args:
            initial_particles (np.ndarray): Starting particles; only used to
                scale the random fallback and as the return value on failure.
            score_fn: Unused; accepted for interface compatibility.
            target_samples: Unused; accepted for interface compatibility.
            return_convergence (bool): If True, also return a small dict with
                an ``"iterations"`` entry.

        Returns:
            np.ndarray, or ``(np.ndarray, dict)`` when ``return_convergence``
            is True. On any unexpected error ``initial_particles`` is returned
            unchanged (with ``{"iterations": 0}``).
        """
        try:
            # Run everything on CPU to avoid device mismatches
            device = torch.device('cpu')
            self.dvrl_model = self.dvrl_model.to('cpu')

            # Initialize DVRL belief state
            batch_size = 1  # For testing
            h_particles, z_particles, weights = self.dvrl_model.init_belief(batch_size, device)
            h_particles = h_particles.to(device)
            z_particles = z_particles.to(device)
            weights = weights.to(device)

            # All-zero action/observation used to drive the belief updates
            dummy_action = torch.zeros(batch_size, self.dvrl_model.action_dim, device=device)
            dummy_obs = torch.zeros(batch_size, self.dvrl_model.obs_dim, device=device)

            # Update belief a few times to stabilize; a failure here is not
            # fatal, we simply continue with the most recent particles.
            try:
                for _ in range(3):
                    h_particles, z_particles, weights, _ = self.dvrl_model.update_belief(
                        h_particles.to(device), z_particles.to(device), weights.to(device),
                        dummy_action, dummy_obs
                    )
            except Exception as e:
                print(f"Warning: Error during belief update: {e}")

            # Extract z particles as our state representation. ``detach`` is
            # required because ``numpy()`` raises on tensors that track
            # gradients, which previously forced the random fallback below.
            try:
                particles = z_particles.detach().squeeze(0).cpu().numpy()
            except Exception:
                # Fallback to random particles scaled like the initial distribution
                print("Warning: Using fallback particle generation")
                n_dims = np.shape(initial_particles)[1]
                particles = (np.random.randn(self.dvrl_model.n_particles, n_dims)
                             * np.std(initial_particles, axis=0))

            # Ensure we have the right number of particles
            n_available = particles.shape[0]
            if n_available < self.n_samples:
                # Cycle through the particles so every one is kept at least
                # once (repeat-then-truncate would drop the trailing ones).
                particles = particles[np.arange(self.n_samples) % n_available]
            elif n_available > self.n_samples:
                # Subsample without replacement if there are too many
                indices = np.random.choice(n_available, self.n_samples, replace=False)
                particles = particles[indices]

            if return_convergence:
                # Return particles and minimal convergence info
                return particles, {"iterations": 1}
            return particles

        except Exception as e:
            print(f"Error in DVRL processing: {e}")
            traceback.print_exc()

            # Return initial particles as fallback
            if return_convergence:
                return initial_particles, {"iterations": 0}
            return initial_particles


class SIRAdapter:
    """
    # & Adapter to use Sequential Importance Resampling (SIR) Particle Filter
    # & for 2D distribution approximation

    Args:
        n_particles (int): Stored on the instance; ``fit_transform`` itself
            uses the number of initial particles it is given.
        n_iter (int): Number of SIR iterations.
        resample_threshold (float): Resample when ESS < threshold * n_particles.
        verbose (bool): Print a resampling message every 20th iteration.
    """
    def __init__(self, n_particles=1000, n_iter=300, resample_threshold=0.5, verbose=True):
        # & Initialize parameters
        self.n_particles = n_particles
        self.n_iter = n_iter
        self.resample_threshold = resample_threshold  # Resample when ESS < threshold * n_particles
        self.verbose = verbose

        # & Track convergence (replaced by a fresh dict on every fit_transform call)
        self.convergence = {'delta_norm_history': []}

    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Run SIR particle filter to approximate the target distribution

        Args:
            initial_particles (np.ndarray): Starting particles; not modified.
            score_fn (callable): Called as ``score_fn(particles, return_logp=True)``
                and expected to return ``(_, log_probs)``; if that raises
                ``TypeError`` it is called as ``score_fn(particles)`` and the
                result is used as the log probabilities.
            target_samples: Unused; accepted for interface compatibility.
            return_convergence (bool): If True, also return the convergence dict.

        Returns:
            np.ndarray, or ``(np.ndarray, dict)`` when ``return_convergence`` is True.
        """
        # & Start a fresh history so repeated calls do not mix their traces
        self.convergence = {'delta_norm_history': []}

        # & Work on a floating-point copy: the in-place noise step below
        # & cannot be applied to integer arrays
        particles = np.array(initial_particles, copy=True)
        if not np.issubdtype(particles.dtype, np.floating):
            particles = particles.astype(float)
        n_particles = len(particles)

        # & Initialize uniform weights
        weights = np.ones(n_particles) / n_particles

        # & Main SIR loop
        prev_particles = particles.copy()
        for i in range(self.n_iter):
            try:
                # & Get log probabilities for current particles
                try:
                    _, log_weights = score_fn(particles, return_logp=True)
                except TypeError:
                    # If that fails, assume score_fn returns the log probabilities directly
                    log_weights = score_fn(particles)

                # & Check for NaN/Inf in log weights
                if np.any(np.isnan(log_weights)) or np.any(np.isinf(log_weights)):
                    print(f"Warning: NaN or Inf in log weights at iteration {i}")
                    # Replace problematic values
                    log_weights = np.nan_to_num(log_weights, nan=-1e10, posinf=-1e10, neginf=-1e10)

                # & Update weights (subtract the max before exponentiating for stability)
                max_log_weight = np.max(log_weights)
                weights = np.exp(log_weights - max_log_weight)

                # & Check for zero weights
                if np.sum(weights) == 0:
                    print(f"Warning: All zero weights at iteration {i}, using uniform weights")
                    weights = np.ones(n_particles) / n_particles
                else:
                    weights = weights / np.sum(weights)

                # & Calculate effective sample size
                ess = 1.0 / np.sum(weights**2)
                normalized_ess = ess / n_particles

                # & Resample if ESS is too low
                if normalized_ess < self.resample_threshold:
                    if self.verbose and i % 20 == 0:
                        print(f"Iteration {i}: Resampling (ESS = {normalized_ess:.4f})")

                    # & Systematic resampling
                    indices = self._systematic_resample(weights)
                    particles = particles[indices]

                    # & Reset weights to uniform
                    weights = np.ones(n_particles) / n_particles

                    # & Add small noise to avoid particle collapse
                    noise_scale = 0.1 * (1.0 - i/self.n_iter)  # Gradually reduce noise
                    particles += np.random.normal(0, noise_scale, particles.shape)

                # & Store delta norm for convergence tracking
                delta_norm = np.linalg.norm(particles - prev_particles) / n_particles
                self.convergence['delta_norm_history'].append(delta_norm)

                prev_particles = particles.copy()

            except Exception as e:
                # & Best effort: report and continue with the current particles
                print(f"Error in SIR iteration {i}: {e}")

        if return_convergence:
            return particles, self.convergence
        return particles

    def _systematic_resample(self, weights):
        """
        # & Perform systematic resampling

        One uniform offset is drawn and ``n`` evenly spaced positions in
        [0, 1) are mapped through the cumulative weights. Each position gets
        the first index whose cumulative weight exceeds it; positions beyond
        the last cumulative value (rounding) map to the last index.
        """
        n = len(weights)
        positions = (np.random.random() + np.arange(n)) / n
        cumulative_sum = np.cumsum(weights)

        # & side='right' returns the first index with cumulative_sum > position
        indices = np.searchsorted(cumulative_sum, positions, side='right')
        return np.minimum(indices, n - 1).astype(int)


# ================================
# Define evaluation metrics for 2D distributions with robust error handling
# ================================

def compute_mmd(X, Y, gamma=0.5, max_value=10.0):
    """
    Compute Maximum Mean Discrepancy between 2D samples X and Y.

    The value returned is the biased estimate
    ``mean(k(X, X)) + mean(k(Y, Y)) - 2 * mean(k(X, Y))`` with an RBF kernel,
    evaluated after standardizing both sample sets with ONE scaler fitted on
    the pooled samples. Standardizing X and Y separately would remove any
    difference in location or scale between them and report a discrepancy of
    (nearly) zero for shifted or rescaled distributions.

    Args:
        X (np.ndarray): First sample set with shape (n_samples, 2)
        Y (np.ndarray): Second sample set with shape (n_samples, 2)
        gamma (float): RBF kernel parameter
        max_value (float): Maximum return value for numerical stability

    Returns:
        float: MMD value, or ``max_value`` if the computation fails or the
        result is NaN/Inf/above ``max_value``.
    """
    try:
        X = np.asarray(X, dtype=float)
        Y = np.asarray(Y, dtype=float)

        # Check inputs
        if not (np.all(np.isfinite(X)) and np.all(np.isfinite(Y))):
            print("Warning: NaN or Inf values detected in MMD inputs")
            # Clean inputs
            X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
            Y = np.nan_to_num(Y, nan=0.0, posinf=1.0, neginf=-1.0)

        # Standardize with a shared transform so that relative shifts and
        # scale differences between X and Y are preserved
        scaler = StandardScaler().fit(np.vstack([X, Y]))
        X_std = scaler.transform(X)
        Y_std = scaler.transform(Y)

        # Compute kernel values
        XX = rbf_kernel(X_std, X_std, gamma)
        YY = rbf_kernel(Y_std, Y_std, gamma)
        XY = rbf_kernel(X_std, Y_std, gamma)

        # Calculate MMD
        mmd = np.mean(XX) + np.mean(YY) - 2 * np.mean(XY)

        # Ensure reasonable value
        if np.isnan(mmd) or np.isinf(mmd) or mmd > max_value:
            print(f"Warning: Unstable MMD value: {mmd}, capping at {max_value}")
            return max_value

        return mmd
    except Exception as e:
        print(f"Error in MMD computation: {e}")
        return max_value

def estimate_kl_divergence_2d(p_samples, q_samples, n_bins=20, max_value=20.0):
    """
    Histogram-based estimate of KL(P||Q) from 2D samples of P and Q.

    Both sample sets are binned on a shared grid spanning the joint 1st-99th
    percentile range (widened by 10% on each side); samples outside the grid
    are ignored by the histogram. A small epsilon is added to every bin
    before normalizing so that empty bins do not produce log(0).

    Args:
        p_samples (np.ndarray): Samples from distribution P with shape (n_samples, 2)
        q_samples (np.ndarray): Samples from distribution Q with shape (n_samples, 2)
        n_bins (int): Number of bins per dimension for histogram approximation
        max_value (float): Maximum return value for numerical stability

    Returns:
        float: Estimated KL divergence, or ``max_value`` when there are fewer
        than 10 samples in either set, the estimate is NaN/Inf/too large, or
        an error occurs.
    """
    try:
        # Guard: the histogram estimate is meaningless with a handful of points
        if min(len(p_samples), len(q_samples)) < 10:
            print("Warning: Too few samples for KL divergence estimation")
            return max_value

        # Replace non-finite entries before binning
        all_finite = np.all(np.isfinite(p_samples)) and np.all(np.isfinite(q_samples))
        if not all_finite:
            print("Warning: NaN or Inf values detected in KL inputs")
            p_samples = np.nan_to_num(p_samples, nan=0.0, posinf=1.0, neginf=-1.0)
            q_samples = np.nan_to_num(q_samples, nan=0.0, posinf=1.0, neginf=-1.0)

        def _shared_edges(axis):
            # Bin edges along one axis, covering both sample sets
            lower = min(np.percentile(p_samples[:, axis], 1),
                        np.percentile(q_samples[:, axis], 1))
            upper = max(np.percentile(p_samples[:, axis], 99),
                        np.percentile(q_samples[:, axis], 99))
            pad = 0.1 * (upper - lower)
            return np.linspace(lower - pad, upper + pad, n_bins + 1)

        grid = [_shared_edges(0), _shared_edges(1)]

        # Density histograms on the shared grid
        p_density = np.histogram2d(p_samples[:, 0], p_samples[:, 1],
                                   bins=grid, density=True)[0]
        q_density = np.histogram2d(q_samples[:, 0], q_samples[:, 1],
                                   bins=grid, density=True)[0]

        # Smooth with epsilon, then renormalize to probability masses
        epsilon = 1e-10
        p_mass = p_density + epsilon
        q_mass = q_density + epsilon
        p_mass = p_mass / np.sum(p_mass)
        q_mass = q_mass / np.sum(q_mass)

        # KL(P||Q) = sum_i p_i * log(p_i / q_i)
        kl = np.sum(p_mass * np.log(p_mass / q_mass))

        if not np.isfinite(kl) or kl > max_value:
            print(f"Warning: Unstable KL divergence value: {kl}, capping at {max_value}")
            return max_value

        return kl
    except Exception as e:
        print(f"Error in KL divergence computation: {e}")
        return max_value

# NOTE(review): compute_mode_coverage_2d may also still need fixing.
def compute_mode_coverage_2d(particles, mode_means, mode_covs, threshold=3.0):
    """
    Measure what fraction of the target modes is reached by the particles.

    A mode counts as covered when at least one particle lies within
    ``threshold`` Mahalanobis distance of its mean (using the mode's
    covariance plus a 1e-6 diagonal regularizer). If the covariance cannot
    be inverted, plain Euclidean distance is used for that mode instead.

    Args:
        particles (np.ndarray): Particles with shape (n_particles, 2)
        mode_means (np.ndarray): Means of the modes with shape (n_modes, 2)
        mode_covs (list or np.ndarray): Covariance matrices for each mode
        threshold (float): Mahalanobis distance threshold to consider a mode covered

    Returns:
        float: Fraction of modes covered by particles (0.0 to 1.0); 0.0 if
        an error occurs.
    """
    try:
        radius_sq = threshold**2
        covered = np.zeros(len(mode_means), dtype=bool)

        for mode_idx, center in enumerate(mode_means):
            shape_matrix = mode_covs[mode_idx]
            offsets = particles - center

            try:
                # Regularize slightly so the inverse exists for near-singular covariances
                precision = np.linalg.inv(shape_matrix + 1e-6 * np.eye(shape_matrix.shape[0]))
                sq_dist = np.sum(offsets @ precision * offsets, axis=1)
            except np.linalg.LinAlgError:
                # Inversion failed: fall back to squared Euclidean distance
                sq_dist = np.sum(offsets**2, axis=1)

            # One particle inside the radius is enough to mark the mode
            if np.any(sq_dist < radius_sq):
                covered[mode_idx] = True

        return np.mean(covered)
    except Exception as e:
        print(f"Error in mode coverage computation: {e}")
        return 0.0  # Default: no modes covered

def compute_correlation_error(target_samples, approx_samples, mode_means, threshold=2.0, default_error=20.0):
    """
    Average per-mode mismatch between target and approximate covariance structure.

    Every sample is assigned to its nearest mode mean (Euclidean) and kept
    only if it lies closer than ``threshold``. For each mode with more than
    10 samples on both sides, the error is the Frobenius norm of the
    difference between the two sample covariance matrices, capped at
    ``default_error``. A mode without enough samples contributes
    ``default_error`` to the average, but only when at least one approximate
    sample was assigned to it; otherwise it is left out of the average.

    Args:
        target_samples (np.ndarray): Samples from target distribution with shape (n_samples, 2)
        approx_samples (np.ndarray): Samples from approximating distribution with shape (n_samples, 2)
        mode_means (np.ndarray): Means of the target modes with shape (n_modes, 2)
        threshold (float): Distance threshold to assign samples to modes
        default_error (float): Default error value when computation fails

    Returns:
        float: Average error over the counted modes, or ``default_error`` if
        no mode was counted or an error occurs.
    """
    try:
        if len(target_samples) == 0 or len(approx_samples) == 0:
            print("Warning: Empty sample set for correlation error computation")
            return default_error

        # Replace non-finite entries before assigning samples to modes
        has_bad_values = (
            np.any(np.isnan(target_samples)) or np.any(np.isinf(target_samples))
            or np.any(np.isnan(approx_samples)) or np.any(np.isinf(approx_samples))
        )
        if has_bad_values:
            print("Warning: NaN or Inf values detected in correlation error inputs")
            target_samples = np.nan_to_num(target_samples, nan=0.0, posinf=1.0, neginf=-1.0)
            approx_samples = np.nan_to_num(approx_samples, nan=0.0, posinf=1.0, neginf=-1.0)

        n_modes = mode_means.shape[0]

        def _group_by_mode(samples):
            # Bucket each sample under its nearest mode if it is within the threshold
            groups = [[] for _ in range(n_modes)]
            for point in samples:
                dists = np.sqrt(np.sum((point - mode_means)**2, axis=1))
                nearest = np.argmin(dists)
                if dists[nearest] < threshold:
                    groups[nearest].append(point)
            return groups

        target_groups = _group_by_mode(target_samples)
        approx_groups = _group_by_mode(approx_samples)

        n_counted = 0
        error_sum = 0.0

        for mode_idx in range(n_modes):
            target_group = target_groups[mode_idx]
            approx_group = approx_groups[mode_idx]

            # Too few samples to estimate a covariance for this mode
            if not (len(target_group) > 10 and len(approx_group) > 10):
                # Penalize only when the approximation placed something here
                if len(approx_group) > 0:
                    n_counted += 1
                    error_sum += default_error
                continue

            target_points = np.array(target_group)
            approx_points = np.array(approx_group)

            try:
                cov_gap = np.cov(target_points.T) - np.cov(approx_points.T)
                mode_error = np.linalg.norm(cov_gap, 'fro')

                # Cap unstable or excessive values
                if np.isnan(mode_error) or np.isinf(mode_error) or mode_error > default_error:
                    mode_error = default_error
            except Exception as e:
                print(f"Error computing covariance for mode {mode_idx}: {e}")
                mode_error = default_error

            n_counted += 1
            error_sum += mode_error

        if n_counted > 0:
            return error_sum / n_counted
        return default_error
    except Exception as e:
        print(f"Error in correlation error computation: {e}")
        return default_error

# Fixed version of compute_ess: accepts either a callable or an object exposing log_prob.
def compute_ess(particles, score_fn):
    """
    Normalized Effective Sample Size (ESS) of a particle set.

    Log probabilities are obtained in this order: ``score_fn.log_prob(particles)``
    if that attribute exists; otherwise ``score_fn(particles, return_logp=True)``
    (second element of the result); otherwise, if that call raises
    ``TypeError`` or ``ValueError``, ``score_fn(particles)``. They are turned
    into self-normalized weights and ESS = 1 / sum(w_i^2) is divided by the
    number of particles.

    Args:
        particles (np.ndarray): Particle set with shape (n_particles, dim)
        score_fn (callable): Function to score particles (can be either a function or an object with log_prob)

    Returns:
        float: Normalized ESS clipped to [0.0, 1.0]; 0.0 on any failure.
    """
    try:
        count = len(particles)

        # Obtain log probabilities through whichever interface score_fn offers
        try:
            if hasattr(score_fn, 'log_prob'):
                raw_logp = score_fn.log_prob(particles)
            else:
                try:
                    _, raw_logp = score_fn(particles, return_logp=True)
                except (TypeError, ValueError):
                    raw_logp = score_fn(particles)
        except Exception as e:
            print(f"Error computing log probabilities for ESS: {e}")
            return 0.0  # Return minimum ESS

        # Sanitize, then shift by the maximum so exp() cannot overflow
        logp = np.nan_to_num(np.array(raw_logp), nan=-1e10, posinf=-1e10, neginf=-1e10)
        unnormalized = np.exp(logp - np.max(logp))

        total = np.sum(unnormalized)
        if total <= 0:
            return 0.0  # All weights are zero or negative
        weights = unnormalized / total

        # ESS relative to the particle count, kept inside [0, 1]
        ess = 1.0 / np.sum(weights**2)
        return np.clip(ess / count, 0.0, 1.0)
    except Exception as e:
        print(f"Error in ESS computation: {e}")
        return 0.0  # Return minimum ESS

def compute_sliced_wasserstein_distance(X, Y, n_projections=50, max_value=20.0):
    """
    Compute the Sliced Wasserstein Distance between two sample sets.

    Samples are projected onto ``n_projections`` random unit directions
    (drawn from NumPy's global random state). Along each direction the sorted
    projections are compared element-wise, and the mean absolute differences
    are averaged over all directions. When the sets differ in size, the
    smaller set's sorted projections are linearly interpolated up to the
    larger size.

    The projection directions are drawn in the dimensionality of ``X``, so
    the function works for any number of columns, not only 2.

    Args:
        X (np.ndarray): First sample set with shape (n_samples, dim)
        Y (np.ndarray): Second sample set with shape (n_samples, dim)
        n_projections (int): Number of random projections
        max_value (float): Maximum return value for numerical stability

    Returns:
        float: Sliced Wasserstein Distance, or ``max_value`` if an input is
        empty, the result is NaN/Inf/above ``max_value``, or an error occurs.
    """
    def _stretch(sorted_proj, n_target):
        # Linearly interpolate each column of sorted projections to n_target rows
        n_source = sorted_proj.shape[0]
        old_indices = np.linspace(0, n_source - 1, n_source)
        new_indices = np.linspace(0, n_source - 1, n_target)
        return np.array([np.interp(new_indices, old_indices, sorted_proj[:, j])
                         for j in range(sorted_proj.shape[1])]).T

    try:
        # Check inputs
        if len(X) == 0 or len(Y) == 0:
            print("Warning: Empty sample set for SWD computation")
            return max_value

        X = np.asarray(X, dtype=float)
        Y = np.asarray(Y, dtype=float)

        # Check for NaN/Inf values
        if not (np.all(np.isfinite(X)) and np.all(np.isfinite(Y))):
            print("Warning: NaN or Inf values detected in SWD inputs")
            # Clean inputs
            X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
            Y = np.nan_to_num(Y, nan=0.0, posinf=1.0, neginf=-1.0)

        n_samples_X = X.shape[0]
        n_samples_Y = Y.shape[0]
        n_dims = X.shape[1]

        # Random unit directions in the sample space (normalized Gaussian draws)
        theta = np.random.normal(0, 1, (n_projections, n_dims))
        theta = theta / np.sqrt(np.sum(theta**2, axis=1, keepdims=True))

        # Project the samples onto each direction and sort along the sample axis
        X_projections = np.sort(np.dot(X, theta.T), axis=0)
        Y_projections = np.sort(np.dot(Y, theta.T), axis=0)

        # Bring the smaller set up to the larger size so rows can be paired
        if n_samples_X > n_samples_Y:
            Y_projections = _stretch(Y_projections, n_samples_X)
        elif n_samples_X < n_samples_Y:
            X_projections = _stretch(X_projections, n_samples_Y)

        # 1D Wasserstein distance along each projection, then average
        wasserstein_distances = np.mean(np.abs(X_projections - Y_projections), axis=0)
        swd = np.mean(wasserstein_distances)

        # Check for numerical stability
        if np.isnan(swd) or np.isinf(swd) or swd > max_value:
            print(f"Warning: Unstable SWD value: {swd}, capping at {max_value}")
            return max_value

        return swd
    except Exception as e:
        print(f"Error in SWD computation: {e}")
        return max_value

def evaluate_method(method_name, particles, target_distribution, target_samples, runtime=None):
    """
    Score one method's particles against the target with every metric in this file.

    Each metric is computed independently; if one raises, a message is
    printed and its entry (both entries, for the KL pair) is set to NaN while
    the remaining metrics are still evaluated.

    Args:
        method_name (str): Name of the method (used only in the error message)
        particles (np.ndarray): Particles from the method
        target_distribution: Target distribution object; its ``means``,
            ``covs`` and ``score`` attributes are used
        target_samples (np.ndarray): Samples from target distribution
        runtime (float, optional): Runtime in seconds

    Returns:
        dict: Metric name -> value, plus 'Runtime (s)' when ``runtime`` is given
    """
    results = {}

    def _record(keys, label, compute):
        # Run one metric computation; on failure report it and store NaN for its keys
        try:
            values = compute()
        except Exception as e:
            print(f"Error in {label} computation: {e}")
            values = (np.nan,) * len(keys)
        results.update(zip(keys, values))

    try:
        _record(('MMD',), 'MMD',
                lambda: (compute_mmd(particles, target_samples),))

        # Both KL directions share one failure path
        _record(('KL(Target||Method)', 'KL(Method||Target)'), 'KL divergence',
                lambda: (estimate_kl_divergence_2d(target_samples, particles),
                         estimate_kl_divergence_2d(particles, target_samples)))

        _record(('Mode Coverage',), 'mode coverage',
                lambda: (compute_mode_coverage_2d(particles, target_distribution.means,
                                                  target_distribution.covs),))

        _record(('Correlation Error',), 'correlation',
                lambda: (compute_correlation_error(
                    target_samples, particles, target_distribution.means),))

        _record(('ESS',), 'ESS',
                lambda: (compute_ess(particles, target_distribution.score),))

        _record(('Sliced Wasserstein',), 'Sliced Wasserstein',
                lambda: (compute_sliced_wasserstein_distance(particles, target_samples),))

        if runtime is not None:
            results['Runtime (s)'] = runtime
    except Exception as e:
        print(f"Error computing metrics for {method_name}: {e}")

        # Last-resort defaults: mark every metric as unavailable
        for key in ('MMD', 'KL(Target||Method)', 'KL(Method||Target)', 'Mode Coverage',
                    'Correlation Error', 'ESS', 'Sliced Wasserstein'):
            results[key] = np.nan
        if runtime is not None:
            results['Runtime (s)'] = runtime

    return results

# ================================
# Define the target 2D distribution
# ================================

def create_target_distribution():
    """
    Build the 2D target GMM: three components with different correlation structures.

    The components sit along the diagonal at (-3, -3), (0, 0) and (3, 3) with
    positively correlated, uncorrelated and negatively correlated covariances
    respectively, and uneven mixture weights 0.3 / 0.4 / 0.3.

    Returns:
        GMMDistribution: Target 2D distribution
    """
    # One entry per component: (mean, covariance, mixture weight)
    components = [
        # Bottom-left: positive correlation
        ([-3.0, -3.0], [[1.0, 0.8], [0.8, 1.0]], 0.3),
        # Center: no correlation (circular)
        ([0.0, 0.0], [[0.5, 0.0], [0.0, 0.5]], 0.4),
        # Top-right: negative correlation
        ([3.0, 3.0], [[1.0, -0.8], [-0.8, 1.0]], 0.3),
    ]

    means = np.array([mean for mean, _, _ in components])
    covs = np.array([cov for _, cov, _ in components])
    weights = np.array([weight for _, _, weight in components])  # sums to 1

    return GMMDistribution(means, covs, weights, name="Target 2D Correlated GMM")


# ================================
# Method factory with improved stability parameters
# ================================

def get_method(method_name, **kwargs):
    """
    Build the inference method identified by ``method_name``.

    The name is matched case-insensitively against "ESCORT", "SVGD", "DVRL"
    and "SIR". Every tunable is read from ``kwargs`` and falls back to the
    default written next to it when the key is absent.

    Args:
        method_name (str): Name of the method to create
        **kwargs: Additional arguments for method configuration

    Returns:
        object: Method instance

    Raises:
        ValueError: If ``method_name`` is not one of the supported names.
    """
    opt = kwargs.get
    kind = method_name.upper()

    if kind == "ESCORT":
        escort_config = dict(
            step_size=opt('step_size', 0.01),  # Reduced from 0.05 to 0.01
            n_iter=opt('n_iter', 300),
            lambda_reg=opt('lambda_reg', 0.005),  # Reduced from 0.01 for stability
            dynamic_bandwidth=opt('dynamic_bandwidth', True),
            enhanced_repulsion=opt('enhanced_repulsion', True),
            noise_level=opt('noise_level', 0.1),  # Increased from 0.05
            noise_decay=opt('noise_decay', 0.98),  # Slower decay
            mode_balancing=opt('mode_balancing', True),
            adaptive_lambda=opt('adaptive_lambda', True),
            # Always on (not configurable through kwargs) for better mode discovery
            aggressive_exploration=True,
            verbose=opt('verbose', True),
        )
        return AdaptiveSVGD(**escort_config)

    if kind == "SVGD":
        svgd_config = dict(
            step_size=opt('step_size', 0.01),  # Reduced from 0.05
            n_iter=opt('n_iter', 300),
            lambda_reg=opt('lambda_reg', 0.0),
            dynamic_bandwidth=opt('dynamic_bandwidth', True),  # Enabled
            enhanced_repulsion=opt('enhanced_repulsion', True),  # Enabled
            noise_level=opt('noise_level', 0.05),  # Added small noise
            noise_decay=opt('noise_decay', 0.98),  # Slower decay
            mode_balancing=opt('mode_balancing', False),
            verbose=opt('verbose', True),
        )
        return SVGD(**svgd_config)

    if kind == "DVRL":
        # NOTE(review): the DVRLAdapter visible near the top of this file is
        # declared as __init__(self, dvrl_model, n_samples=1000); these
        # keywords do not match that signature -- confirm which adapter
        # definition is actually in scope here.
        dvrl_config = dict(
            obs_dim=2,  # 2D for this experiment
            action_dim=1,
            h_dim=opt('h_dim', 64),
            z_dim=2,  # Match 2D dimensionality
            n_particles=opt('n_particles', 30),
            learning_rate=opt('learning_rate', 1e-4),
            n_iter=opt('n_iter', 300),
        )
        return DVRLAdapter(**dvrl_config)

    if kind == "SIR":
        sir_config = dict(
            n_particles=opt('n_particles', 1000),
            n_iter=opt('n_iter', 300),
            resample_threshold=opt('resample_threshold', 0.5),
            verbose=opt('verbose', True),
        )
        return SIRAdapter(**sir_config)

    raise ValueError(f"Unknown method: {method_name}")

def run_experiment(methods_to_run=None, **kwargs):
    """
    Run the full 2D evaluation experiment.

    For each requested method this builds the method via ``get_method``,
    transports a shared set of initial particles towards the target GMM,
    times the run, evaluates the resulting particles with
    ``evaluate_method`` and finally renders the comparison plots.

    Args:
        methods_to_run (list): List of method names to run. If None, runs all methods.
        **kwargs: Additional arguments for method configuration. The
            ``n_particles`` entry (default 1000) also sets the number of
            initial particles handed to every method.

    Returns:
        tuple: ``(results_df, target_gmm, particles_dict, convergence_dict)``
        where ``results_df`` holds one row of metrics per method,
        ``particles_dict`` maps method name to its final particles and
        ``convergence_dict`` maps method name to convergence info (only
        populated for ESCORT).
    """
    print("Starting 2D GMM evaluation experiment...")

    # Set default methods if not specified
    if methods_to_run is None:
        methods_to_run = ["ESCORT", "SVGD", "DVRL", "SIR"]

    # Create target distribution
    target_gmm = create_target_distribution()

    # Extract mode means
    mode_means = target_gmm.means

    # Generate target samples for evaluation
    n_eval_samples = 2000
    target_samples = target_gmm.sample(n_eval_samples)

    # Define score function for particle updates
    def score_fn(x, return_logp=False):
        """Score function (gradient of log density) for the target GMM."""
        try:
            scores = target_gmm.score(x)

            # Check for numerical stability
            if np.any(np.isnan(scores)) or np.any(np.isinf(scores)):
                print("Warning: NaN or Inf in score function output")
                # Replace problematic values with zeros
                scores = np.nan_to_num(scores, nan=0.0, posinf=0.0, neginf=0.0)

            if return_logp:
                log_probs = target_gmm.log_prob(x)

                # Check for numerical stability
                if np.any(np.isnan(log_probs)) or np.any(np.isinf(log_probs)):
                    print("Warning: NaN or Inf in log probabilities")
                    # Replace problematic values
                    log_probs = np.nan_to_num(log_probs, nan=-1e10, posinf=-1e10, neginf=-1e10)

                return scores, log_probs
            return scores
        except Exception as e:
            # Best effort: a failing score evaluation degrades to a zero
            # update instead of aborting the whole method run.
            print(f"Error in score function: {e}")
            if return_logp:
                return np.zeros_like(x), np.ones(len(x)) * -1e10
            return np.zeros_like(x)

    # Generate initial particles from a simple distribution. The particle
    # count follows the configured value so it stays consistent with what
    # get_method passes to the individual methods.
    n_particles = kwargs.get('n_particles', 1000)
    initial_particles = np.random.randn(n_particles, 2) * 2.0

    # List of per-method metric rows
    results = []
    global methods_runtime
    methods_runtime = {}

    # Dictionaries to store particles and convergence info
    particles_dict = {}
    convergence_dict = {}

    # Run each requested method
    for method_name in methods_to_run:
        print(f"\nRunning {method_name}...")

        # Create method instance
        method = get_method(method_name, **kwargs)

        # Time the execution
        start_time = time.time()
        try:
            if hasattr(method, 'fit_transform') and callable(method.fit_transform):
                particles = None

                # ESCORT also reports convergence info. Obtain particles and
                # convergence from ONE run so that the timing is not doubled
                # and both outputs describe the same trajectory.
                if method_name.upper() == "ESCORT":
                    try:
                        particles, convergence = method.fit_transform(
                            initial_particles.copy(),
                            score_fn,
                            target_samples=target_samples,
                            return_convergence=True
                        )
                        convergence_dict[method_name] = convergence
                    except Exception as e:
                        print(f"Error getting convergence info for {method_name}: {e}")
                        particles = None  # fall back to a plain run below

                if particles is None:
                    particles = method.fit_transform(
                        initial_particles.copy(),
                        score_fn,
                        target_samples=target_samples,
                        return_convergence=False
                    )
            else:
                print(f"Warning: {method_name} doesn't have fit_transform method")
                particles = np.random.randn(100, 2)  # Fallback

        except Exception as e:
            print(f"Error running {method_name}: {e}")
            traceback.print_exc()  # Print full traceback
            # Create fallback particles so evaluation and plotting can proceed
            particles = np.random.randn(100, 2)

        methods_runtime[method_name] = time.time() - start_time
        particles_dict[method_name] = particles

        # Evaluate and store results
        method_results = evaluate_method(
            method_name, particles, target_samples,
            target_gmm, mode_means
        )
        results.append(method_results)

    # Create results table (an empty run has no 'Method' column to index by)
    results_df = pd.DataFrame(results)
    if results:
        results_df = results_df.set_index('Method')

    # Print results
    print("\nResults Summary:")
    print(results_df)

    # Visualize results
    visualize_results(
        target_gmm, target_samples, particles_dict,
        convergence_dict, results_df
    )

    return results_df, target_gmm, particles_dict, convergence_dict


def _target_density_grid(target_gmm, target_samples, resolution=100, margin=1):
    """Evaluate the target density on a regular grid around the target samples.

    The grid spans the sample range in each coordinate, padded by ``margin``.

    Returns:
        tuple: ``(xx, yy, probs)`` meshgrid coordinates and density values,
        all of shape ``(resolution, resolution)``.
    """
    x_min, x_max = np.min(target_samples[:, 0]) - margin, np.max(target_samples[:, 0]) + margin
    y_min, y_max = np.min(target_samples[:, 1]) - margin, np.max(target_samples[:, 1]) + margin
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))
    grid_points = np.column_stack([xx.ravel(), yy.ravel()])

    # The distribution exposes log-densities; exponentiate for contouring.
    probs = np.exp(target_gmm.log_prob(grid_points)).reshape(xx.shape)
    return xx, yy, probs


def _add_confidence_ellipse(ax, mean, cov, color, chisquare_val=5.991):
    """Draw the confidence ellipse of a 2D Gaussian component on ``ax``.

    ``chisquare_val`` defaults to 5.991, the 95% quantile for 2 degrees of
    freedom.
    """
    # Imported here so every caller has the name available.
    from matplotlib.patches import Ellipse

    # Principal axes of the covariance, largest eigenvalue first
    eigvals, eigvecs = np.linalg.eigh(cov)
    order = eigvals.argsort()[::-1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]

    # Orientation of the major axis
    angle = np.arctan2(eigvecs[1, 0], eigvecs[0, 0])

    # Full width/height (diameters) of the ellipse
    width = 2 * np.sqrt(chisquare_val * eigvals[0])
    height = 2 * np.sqrt(chisquare_val * eigvals[1])

    ax.add_patch(Ellipse(xy=mean, width=width, height=height,
                         angle=np.degrees(angle), edgecolor=color, fc='none',
                         lw=2, alpha=0.7))


def visualize_results(target_gmm, target_samples, particles_dict, 
                     convergence_dict, results_df):
    """
    Create visualizations of the results for comparison.

    Writes PNG files into a ``plots`` folder next to this script: the target
    distribution, one scatter plot per method, the convergence curves (when
    available), a metrics bar-chart grid and a combined overview figure.
    Each figure is produced independently; a failure in one is reported on
    stdout and does not prevent the others, and every figure is closed even
    when plotting fails.

    Args:
        target_gmm: Target GMM distribution
        target_samples: Samples from target distribution
        particles_dict: Dictionary of particles from each method
        convergence_dict: Dictionary of convergence info for each method
        results_df: DataFrame with evaluation metrics
    """
    # Create visualizer
    viz = GMMVisualizer(cmap='viridis', figsize=(12, 10))

    # Get methods to visualize
    methods = list(particles_dict.keys())

    # Create folder for plots if it doesn't exist
    plots_dir = os.path.join(SCRIPT_DIR, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # The density grid is identical for every figure: compute it on first
    # use (inside the caller's try block) and reuse it afterwards.
    grid_cache = []

    def density_grid():
        if not grid_cache:
            grid_cache.append(_target_density_grid(target_gmm, target_samples))
        return grid_cache[0]

    # Figure 1: Distribution comparisons
    # First, visualize the target distribution
    try:
        fig_target, _ = viz.visualize_2d(target_gmm, title="Target 2D Correlated GMM Distribution", 
                                   show_components=True, n_samples=500)
        try:
            fig_target.savefig(os.path.join(plots_dir, "target_2d_distribution.png"), dpi=300)
        finally:
            plt.close(fig_target)
    except Exception as e:
        print(f"Error visualizing target distribution: {e}")

    # Then visualize each method's approximation
    for method_name, particles in particles_dict.items():
        fig = None
        try:
            # Create scatter plot of particles
            fig = plt.figure(figsize=(10, 8))

            # Plot contours of target distribution
            xx, yy, probs = density_grid()
            plt.contour(xx, yy, probs, colors='k', alpha=0.5, linewidths=0.5)

            # Plot particles
            plt.scatter(particles[:, 0], particles[:, 1], s=10, alpha=0.6, c='red', label=method_name)

            # Plot component means and 95% confidence ellipses
            for i, (mean, cov) in enumerate(zip(target_gmm.means, target_gmm.covs)):
                plt.scatter(mean[0], mean[1], s=100, c='green', edgecolor='black', label=f'Mode {i+1}' if i == 0 else "")
                _add_confidence_ellipse(plt.gca(), mean, cov, 'green')

            plt.title(f"{method_name} Approximation - Mode Coverage: {results_df.loc[method_name, 'Mode Coverage']:.2f}, " +
                    f"Corr. Error: {results_df.loc[method_name, 'Correlation Error']:.4f}")
            plt.xlabel('x₁')
            plt.ylabel('x₂')
            plt.grid(alpha=0.3)
            plt.axis('equal')
            plt.legend(loc='upper right')

            # Save figure
            plt.savefig(os.path.join(plots_dir, f"{method_name}_2d_approximation.png"), dpi=300)
        except Exception as e:
            print(f"Error visualizing {method_name} approximation: {e}")
        finally:
            # Release the figure even when plotting failed part-way
            if fig is not None:
                plt.close(fig)

    # Figure 2: Convergence plots for methods with convergence info
    if convergence_dict:
        fig = None
        try:
            fig = plt.figure(figsize=(12, 6 * len(convergence_dict)))

            for i, (method_name, convergence) in enumerate(convergence_dict.items()):
                plt.subplot(len(convergence_dict), 1, i + 1)

                # Get convergence history
                history = convergence.get('delta_norm_history', [])
                if len(history) > 0:
                    # Clip extremely large values for better visualization
                    clipped_history = np.clip(history, 0, np.percentile(history, 95) * 2)
                    plt.plot(clipped_history, linewidth=2)
                    plt.xlabel('Iteration', fontsize=12)
                    plt.ylabel('Update Magnitude', fontsize=12)
                    plt.title(f'{method_name} Convergence (clipped)', fontsize=14)
                    plt.grid(True, alpha=0.3)
                else:
                    plt.text(0.5, 0.5, "No convergence data available", 
                            ha='center', va='center', fontsize=14)

            plt.tight_layout()
            plt.savefig(os.path.join(plots_dir, "convergence_2d.png"), dpi=300)
        except Exception as e:
            print(f"Error plotting convergence: {e}")
        finally:
            if fig is not None:
                plt.close(fig)

    # Figure 3: Metrics comparison
    fig = None
    try:
        fig = plt.figure(figsize=(14, 10))

        # Extract metrics for plotting
        metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)', 
                'Mode Coverage', 'Correlation Error', 'ESS', 'Sliced Wasserstein']

        # Create bar plots for each metric
        for i, metric in enumerate(metrics):
            plt.subplot(3, 3, i+1)

            # Keep the raw values for labelling; bars may be clipped below
            raw_values = results_df[metric]
            values = raw_values

            # Clip very large values for better visualization
            if values.max() > 10 * values.min():
                print(f"Clipping extreme values for {metric}")
                values = values.clip(upper=values.median() * 3)

            bars = plt.bar(values.index, values.values)

            # Annotate each bar with the true (unclipped) metric value
            for bar, raw in zip(bars, raw_values.values):
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{raw:.4f}', ha='center', va='bottom', rotation=0,
                        fontsize=9)

            plt.title(metric)
            plt.xticks(rotation=45)
            plt.grid(axis='y', alpha=0.3)

            # For Mode Coverage and ESS, higher is better
            if metric in ['Mode Coverage', 'ESS']:
                plt.ylim(0, 1.1)
                plt.title(f"{metric} (higher is better)")
            # For error metrics, lower is better
            else:
                plt.title(f"{metric} (lower is better)")

        # Add runtime comparison. The runtime column is only present when a
        # runtime was recorded during evaluation, so its absence must not
        # abort the whole figure.
        plt.subplot(3, 3, 8)
        if 'Runtime (s)' in results_df.columns:
            runtime_values = results_df['Runtime (s)']
            bars = plt.bar(runtime_values.index, runtime_values.values)
            for bar in bars:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{height:.2f}s', ha='center', va='bottom', rotation=0,
                        fontsize=9)
            plt.xticks(rotation=45)
            plt.grid(axis='y', alpha=0.3)
        else:
            plt.axis('off')
            plt.text(0.5, 0.5, "No runtime data available",
                    ha='center', va='center', fontsize=10)
        plt.title("Runtime (seconds)")

        # Add summary visualization
        plt.subplot(3, 3, 9)
        # Create a simple table with key metrics
        plt.axis('off')
        table_data = [['Method', 'Mode Cov.', 'Corr. Error']]
        for method in methods:
            table_data.append([
                method, 
                f"{results_df.loc[method, 'Mode Coverage']:.2f}", 
                f"{results_df.loc[method, 'Correlation Error']:.2f}"
            ])
        table = plt.table(cellText=table_data, loc='center', cellLoc='center', colWidths=[0.4, 0.3, 0.3])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)
        plt.title("Summary Metrics")

        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, "metrics_comparison_2d.png"), dpi=300)
    except Exception as e:
        print(f"Error plotting metrics comparison: {e}")
    finally:
        if fig is not None:
            plt.close(fig)

    # Create a combined visualization for paper-like presentation
    fig = None
    try:
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle("Comparison of Methods for 2D Multi-modal Correlated Distribution Approximation", 
                    fontsize=16)

        # Target distribution (larger plot): replace the first two cells of
        # the top row by a single axes spanning both, so no empty axes are
        # left underneath it.
        grid_spec = axes[0, 0].get_gridspec()
        for covered_ax in axes[0, :2]:
            covered_ax.remove()
        ax_target = fig.add_subplot(grid_spec[0, :2])

        xx, yy, probs = density_grid()

        # Plot contours of target distribution
        ax_target.contour(xx, yy, probs, colors='k', alpha=0.7, linewidths=1.0)
        ax_target.scatter(target_samples[:, 0], target_samples[:, 1], s=10, alpha=0.3, c='blue')

        # Plot component means and ellipses
        for i, (mean, cov) in enumerate(zip(target_gmm.means, target_gmm.covs)):
            ax_target.scatter(mean[0], mean[1], s=100, c='red', edgecolor='black', 
                        label=f'Mode {i+1}' if i == 0 else "")
            _add_confidence_ellipse(ax_target, mean, cov, 'red')

        ax_target.set_title("Target Distribution with Correlated Modes")
        ax_target.set_xlabel('x₁')
        ax_target.set_ylabel('x₂')
        ax_target.grid(alpha=0.3)
        ax_target.set_aspect('equal')

        # Method plots (one per available method, up to 4)
        method_positions = [(0, 2), (1, 0), (1, 1), (1, 2)]
        shown_methods = methods[:len(method_positions)]

        for (row, col), method_name in zip(method_positions, shown_methods):
            ax = axes[row, col]
            particles = particles_dict[method_name]

            # Plot contours of target distribution
            ax.contour(xx, yy, probs, colors='k', alpha=0.3, linewidths=0.5)

            # Plot particles
            ax.scatter(particles[:, 0], particles[:, 1], s=8, alpha=0.6, c='red')

            # Set title with metrics
            ax.set_title(f"{method_name}\nMode Cov: {results_df.loc[method_name, 'Mode Coverage']:.2f}, " + 
                    f"Corr Err: {results_df.loc[method_name, 'Correlation Error']:.2f}")
            ax.set_xlabel('x₁')
            ax.set_ylabel('x₂')
            ax.grid(alpha=0.3)
            ax.set_aspect('equal')

        # Hide panels that have no method to show
        for row, col in method_positions[len(shown_methods):]:
            axes[row, col].set_visible(False)

        fig.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust for the suptitle
        fig.savefig(os.path.join(plots_dir, "combined_visualization_2d.png"), dpi=300)
    except Exception as e:
        print(f"Error creating combined visualization: {e}")
    finally:
        if fig is not None:
            plt.close(fig)

    print(f"Visualizations saved to {plots_dir}")


# ================================
# Run the experiment
# ================================

if __name__ == "__main__":
    # Command-line entry point: parse options, run the experiment for the
    # requested methods and write the metrics table to a CSV file.

    # Set up argument parser
    parser = argparse.ArgumentParser(description='ESCORT Framework Evaluation on 2D Multi-Modal Correlated GMM (Fixed Version)')
    parser.add_argument('--methods', nargs='+', default=['ESCORT', 'SVGD', 'DVRL', 'SIR'],
                    help='Methods to evaluate (default: ESCORT SVGD DVRL SIR)')
    parser.add_argument('--n_iter', type=int, default=300, 
                    help='Number of iterations (default: 300)')
    parser.add_argument('--step_size', type=float, default=0.01,  # Reduced from 0.05
                    help='Step size for updates (default: 0.01)')
    parser.add_argument('--no_verbose', action='store_false', dest='verbose',
                    help='Disable verbose output (default: verbose enabled)')
    parser.add_argument('--lambda_reg', type=float, default=0.005,  # Reduced from 0.01
                    help='GSWD regularization weight (default: 0.005)')
    parser.add_argument('--noise_level', type=float, default=0.1,  # Increased from 0.05
                    help='Noise level for exploration (default: 0.1)')

    # The boolean options below default to True, so the positive flag alone
    # can never change anything. Each one gets a --no_* counterpart (same
    # pattern as --no_verbose) so the feature can actually be switched off.
    parser.add_argument('--enhanced_repulsion', action='store_true', default=True,  # Changed default
                    help='Enable enhanced repulsion forces (default: True)')
    parser.add_argument('--no_enhanced_repulsion', action='store_false', dest='enhanced_repulsion',
                    help='Disable enhanced repulsion forces')
    parser.add_argument('--noise_decay', type=float, default=0.98,  # Changed from 0.95
                    help='Noise decay rate (default: 0.98)')
    parser.add_argument('--h_dim', type=int, default=64, 
                    help='Dimension of RNN hidden state (for DVRL, default: 64)')
    parser.add_argument('--z_dim', type=int, default=2, 
                    help='Dimension of stochastic latent state (for DVRL, default: 2 for 2D problems)')
    parser.add_argument('--resample_threshold', type=float, default=0.5,
                    help='Threshold for resampling in SIR (default: 0.5)')
    parser.add_argument('--n_particles', type=int, default=1000,
                    help='Number of particles to use (default: 1000)')
    parser.add_argument('--dynamic_bandwidth', action='store_true', default=True,  # Changed default
                    help='Enable dynamic bandwidth for kernel (default: True)')
    parser.add_argument('--no_dynamic_bandwidth', action='store_false', dest='dynamic_bandwidth',
                    help='Disable dynamic bandwidth for kernel')
    parser.add_argument('--min_bandwidth', type=float, default=0.1,
                    help='Minimum bandwidth constraint (default: 0.1)')
    parser.add_argument('--mode_balancing', action='store_true', default=True,  # Changed default
                    help='Enable mode balancing for multi-modal distributions (default: True)')
    parser.add_argument('--no_mode_balancing', action='store_false', dest='mode_balancing',
                    help='Disable mode balancing for multi-modal distributions')
    parser.add_argument('--adaptive_lambda', action='store_true', default=True,  # Changed default
                    help='Enable adaptive lambda for regularization (default: True)')
    parser.add_argument('--no_adaptive_lambda', action='store_false', dest='adaptive_lambda',
                    help='Disable adaptive lambda for regularization')
    parser.add_argument('--gradient_clip', type=float, default=10.0,
                    help='Gradient clipping threshold (default: 10.0)')

    # Parse arguments
    args = parser.parse_args()

    # Configure parameters with improved defaults
    method_params = {
        'n_iter': args.n_iter,
        'step_size': args.step_size,
        'verbose': args.verbose,
        'lambda_reg': args.lambda_reg,
        'noise_level': args.noise_level,
        'enhanced_repulsion': args.enhanced_repulsion,
        'noise_decay': args.noise_decay,
        'h_dim': args.h_dim,
        'z_dim': 2,  # Force z_dim to 2 for 2D problem (--z_dim is accepted but not used)
        'n_particles': args.n_particles,
        'resample_threshold': args.resample_threshold,
        'dynamic_bandwidth': args.dynamic_bandwidth,
        'min_bandwidth': args.min_bandwidth,
        'mode_balancing': args.mode_balancing,
        'adaptive_lambda': args.adaptive_lambda,
        'gradient_clip': args.gradient_clip
    }

    # Set global variable for runtime tracking
    methods_runtime = {}

    # Run the experiment with specified methods
    results_df, target_gmm, particles_dict, convergence_dict = run_experiment(
        methods_to_run=args.methods,
        **method_params
    )

    # Save results to CSV in the script directory
    results_dir = os.path.join(SCRIPT_DIR, "results")
    os.makedirs(results_dir, exist_ok=True)
    results_df.to_csv(os.path.join(results_dir, "escort_2d_results.csv"))

    print("\nExperiment complete. Results saved to CSV and visualizations saved in the plots directory.")
    print(f"Results saved in: {results_dir}")
    print(f"Visualizations saved in: {os.path.join(SCRIPT_DIR, 'plots')}")
