"""
    ESCORT Framework Evaluation on 1D Multi-Modal GMM
    
    This script tests the ESCORT framework against a challenging 1D GMM distribution
    with multiple modes of different weights and variances.
    
    This version uses multiple random seeds and reports metrics as mean ± standard error.
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde, wasserstein_distance, sem
from sklearn.metrics.pairwise import rbf_kernel
import pandas as pd
import time
import os
import argparse
import torch
import torch.optim as optim

from dvrl.dvrl import DVRL
from belief_assessment.distributions import GMMDistribution
from escort.svgd import SVGD, AdaptiveSVGD  
from belief_assessment.evaluation.visualize_distributions import GMMVisualizer

# Absolute path of the directory that contains this script (derived from
# __file__), kept as the base location for saving outputs.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


class DVRLAdapter:
    """
    # & Adapter to use DVRL for 1D distribution approximation

    Gives a ``DVRL`` model the same ``fit_transform`` entry point as the other
    particle methods in this file. The latent ``z`` particles produced by the
    model are mapped to the 1D target space by a linear projection layer that
    is trained together with the model.
    """
    def __init__(self, 
                obs_dim=1, 
                action_dim=1, 
                h_dim=64, 
                z_dim=32, 
                n_particles=30,
                learning_rate=1e-4,
                n_iter=300,
                device=None):
        """
        Args:
            obs_dim (int): Observation dimension passed to ``DVRL``.
            action_dim (int): Action dimension passed to ``DVRL``.
            h_dim (int): ``h_dim`` passed to ``DVRL``.
            z_dim (int): ``z_dim`` passed to ``DVRL``.
            n_particles (int): ``n_particles`` passed to ``DVRL``.
            learning_rate (float): Adam learning rate.
            n_iter (int): Number of training iterations run by ``fit_transform``.
            device (torch.device, optional): Device to use; defaults to CUDA
                when available, otherwise CPU.
        """
        # & Initialize parameters
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.n_particles = n_particles
        self.n_iter = n_iter
        self.learning_rate = learning_rate
        
        # & Set device
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # & Create DVRL model
        self.model = DVRL(
            obs_dim=obs_dim,
            action_dim=action_dim,
            h_dim=h_dim,
            z_dim=z_dim,
            n_particles=n_particles,
            continuous_actions=True
        ).to(self.device)
        
        # & Set up optimizer
        # NOTE: fit_transform builds its own optimizer that also covers the
        # projection layer; this attribute is kept so code reading it still works.
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        
        # & Track convergence
        self.convergence = {'delta_norm_history': []}
        
    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Train DVRL on the 1D distribution and return final particles

        Args:
            initial_particles (np.ndarray): Initial particle set. Only its mean
                over axis 0 is used, as the constant observation fed to the model.
            score_fn (callable): Called with a NumPy array of projected particles;
                its result is used as the target score at those points.
            target_samples: Unused; accepted for interface compatibility.
            return_convergence (bool): If True, also return the convergence dict.

        Returns:
            np.ndarray, or ``(np.ndarray, dict)`` when ``return_convergence`` is
            True: the final latent particles projected to one output column, and
            the dict holding ``'delta_norm_history'`` (one entry per iteration).
        """
        # & Convert initial particles to torch tensor
        initial_particles = torch.FloatTensor(initial_particles).to(self.device)
        
        # & Initialize belief state
        batch_size = 1  # Single batch for this test case
        h_init, z_init, w_init = self.model.init_belief(batch_size, self.device)
        
        # Read the latent size from z_init itself (it must be 3-dimensional) so the
        # projection layer matches whatever the model actually produced
        _, _, actual_z_dim = z_init.shape
        print(f"DVRL belief shapes: h={h_init.shape}, z={z_init.shape}, w={w_init.shape}")
        
        # & Create a simple projection layer to map from actual z-space to target space (1D)
        projection_layer = torch.nn.Linear(actual_z_dim, 1).to(self.device)
        
        # & Initialize projection layer with random weights
        torch.nn.init.normal_(projection_layer.weight, mean=0.0, std=0.1)
        torch.nn.init.zeros_(projection_layer.bias)
        
        # & Optimize model and projection parameters jointly
        optimizer = optim.Adam(list(self.model.parameters()) + list(projection_layer.parameters()), 
                            lr=self.learning_rate)
        
        # & Use the original particles from init_belief
        h_particles = h_init.clone()
        z_particles = z_init.clone()
        weights = w_init.clone()
        
        # & Create a representative observation tensor with proper shape
        observation = torch.zeros(batch_size, self.obs_dim).to(self.device)
        
        # & Create dummy actions
        dummy_action = torch.zeros(batch_size, self.action_dim).to(self.device)
        
        # The observation is the mean of the initial particles; it does not change
        # between iterations, so compute it once
        observation_value = torch.mean(initial_particles, dim=0)
        
        # & Training loop
        for i in range(self.n_iter):
            optimizer.zero_grad()
            
            # Re-assign every iteration in case the model modifies its input
            observation[0] = observation_value
            
            # & Update belief with observation
            # Clone tensors to ensure they're fresh each iteration
            h_particles_clone = h_particles.clone()
            z_particles_clone = z_particles.clone()
            weights_clone = weights.clone()
            
            update_results = self.model.update_belief(
                h_particles_clone, z_particles_clone, weights_clone, dummy_action, observation
            )
            
            # & Only the first two return values (h and z particles) are used
            h_particles_new = update_results[0]
            z_particles_new = update_results[1]
            
            # & Project z_particles to 1D space to match target GMM
            z_projected = projection_layer(z_particles_new.squeeze(0))  # Shape: [n_particles, 1]
            
            # & Scores are evaluated outside autograd and act as constants in the loss
            z_projected_np = z_projected.detach().cpu().numpy()
            scores = score_fn(z_projected_np)
            scores_tensor = torch.tensor(scores, device=self.device, dtype=torch.float32)
            
            # Match the scores to the projected particles' shape so the product is
            # element-wise whether score_fn returns shape (n,) or (n, 1); a bare
            # unsqueeze(1) would broadcast an (n, 1) result to (n, n, 1)
            scores_tensor = scores_tensor.reshape(z_projected.shape)
            
            # & Create the loss - maximize score (minimize negative score)
            loss = -torch.mean(z_projected * scores_tensor)
            
            # & Add regularization to encourage diversity in particles
            diversity_loss = -0.1 * torch.mean(torch.pdist(z_projected))
            loss += diversity_loss
            
            # & Backpropagate and update
            loss.backward()
            optimizer.step()
            
            # & Track delta norm for convergence
            # Measured BEFORE z_particles is overwritten; comparing afterwards would
            # always give exactly zero
            z_particles_next = z_particles_new.detach()
            delta_norm = torch.norm(z_particles_next - z_particles).item()
            self.convergence['delta_norm_history'].append(delta_norm)
            
            # & Update particles for next iteration
            h_particles = h_particles_new.detach()
            z_particles = z_particles_next
            # The weights carried forward are the ones that were passed in
            weights = weights_clone.detach()
        
        # & For final evaluation, project the z_particles to 1D space
        with torch.no_grad():
            z_projected_final = projection_layer(z_particles.squeeze(0))
            final_particles = z_projected_final.detach().cpu().numpy()
        
        if return_convergence:
            return final_particles, self.convergence
        return final_particles
    

class SIRAdapter:
    """
    # & Adapter to use Sequential Importance Resampling (SIR) Particle Filter
    # & for 1D distribution approximation

    Exposes the same ``fit_transform`` entry point as the other particle
    methods in this file.
    """
    def __init__(self, n_particles=1000, n_iter=300, resample_threshold=0.5, verbose=True):
        """
        Args:
            n_particles (int): Stored for reference; ``fit_transform`` takes the
                particle count from the array it is given.
            n_iter (int): Number of weighting / resampling iterations.
            resample_threshold (float): Resample when the normalized ESS drops
                below this value.
            verbose (bool): Print a message on resampling (every 20th iteration).
        """
        # & Initialize parameters
        self.n_particles = n_particles
        self.n_iter = n_iter
        self.resample_threshold = resample_threshold  # Resample when ESS < threshold * n_particles
        self.verbose = verbose
        
        # & Track convergence
        self.convergence = {'delta_norm_history': []}
    
    def fit_transform(self, initial_particles, score_fn, target_samples=None, return_convergence=False):
        """
        # & Run SIR particle filter to approximate the target distribution

        Args:
            initial_particles (np.ndarray): Starting particles; copied, not modified.
            score_fn (callable): Called as ``score_fn(particles, return_logp=True)``;
                the second returned value is used as the log weights.
            target_samples: Unused; accepted for interface compatibility.
            return_convergence (bool): If True, also return the convergence dict.

        Returns:
            np.ndarray, or ``(np.ndarray, dict)`` when ``return_convergence`` is True.
        """
        # & Work on a copy so the caller's array is left untouched
        particles = initial_particles.copy()
        n_particles = len(particles)
        
        # & Main SIR loop
        prev_particles = particles.copy()
        for i in range(self.n_iter):
            # & Get log probabilities for current particles
            _, log_weights = score_fn(particles, return_logp=True)
            
            # & Weights are rebuilt from the log probabilities every iteration;
            # & subtracting the maximum keeps the exponentials from underflowing
            weights = np.exp(log_weights - np.max(log_weights))
            weights = weights / np.sum(weights)
            
            # & Calculate effective sample size
            ess = 1.0 / np.sum(weights**2)
            normalized_ess = ess / n_particles
            
            # & Resample if ESS is too low
            if normalized_ess < self.resample_threshold:
                if self.verbose and i % 20 == 0:
                    print(f"Iteration {i}: Resampling (ESS = {normalized_ess:.4f})")
                
                # & Systematic resampling
                indices = self._systematic_resample(weights)
                particles = particles[indices]
                
                # & Add small noise to avoid particle collapse
                particles += np.random.normal(0, 0.1, particles.shape)
            
            # & Store delta norm for convergence tracking
            delta_norm = np.linalg.norm(particles - prev_particles) / n_particles
            self.convergence['delta_norm_history'].append(delta_norm)
            
            prev_particles = particles.copy()
        
        if return_convergence:
            return particles, self.convergence
        return particles
    
    def _systematic_resample(self, weights):
        """
        # & Perform systematic resampling

        One uniform offset is drawn and ``n`` evenly spaced positions in [0, 1)
        are mapped to the first particle whose cumulative weight exceeds them.

        Args:
            weights (np.ndarray): Non-negative weights; they are normalized by
                their sum here, so round-off or unnormalized input is tolerated.

        Returns:
            np.ndarray: Integer indices of the selected particles (length ``n``).

        Raises:
            ValueError: If the weights do not have a positive, finite sum.
        """
        n = len(weights)
        positions = (np.random.random() + np.arange(n)) / n
        if n == 0:
            return np.zeros(0, dtype=int)
        
        cumulative_sum = np.cumsum(weights)
        total = cumulative_sum[-1]
        if not np.isfinite(total) or total <= 0:
            raise ValueError("weights must have a positive, finite sum")
        
        # Dividing by the total makes the last entry exactly 1.0, so no position
        # can fall beyond the end of the cumulative distribution
        cumulative_sum = cumulative_sum / total
        
        # side='right' gives the first index whose cumulative weight is strictly
        # greater than the position; the clip guards a position that rounds to 1.0
        indices = np.searchsorted(cumulative_sum, positions, side='right')
        return np.minimum(indices, n - 1).astype(int)


# ================================
# Define evaluation metrics
# ================================

def compute_mmd(X, Y, gamma=0.5):
    """
    Compute Maximum Mean Discrepancy between samples X and Y.

    The value returned is mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)
    for an RBF kernel k, where each mean runs over the full kernel matrix.

    Args:
        X (np.ndarray): First sample set (reshaped to a single column)
        Y (np.ndarray): Second sample set (reshaped to a single column)
        gamma (float): RBF kernel parameter

    Returns:
        float: MMD value
    """
    x_col, y_col = X.reshape(-1, 1), Y.reshape(-1, 1)

    def mean_kernel(left, right):
        # Average of every pairwise kernel evaluation between the two sets
        return np.mean(rbf_kernel(left, right, gamma))

    within_x = mean_kernel(x_col, x_col)
    within_y = mean_kernel(y_col, y_col)
    across = mean_kernel(x_col, y_col)

    return within_x + within_y - 2 * across

def estimate_kl_divergence(p_samples, q_kde, n_bins=100):
    """
    Estimate KL(P||Q) from samples of P and a KDE of Q.

    Both distributions are discretized on a shared grid of ``n_bins`` equal-width
    bins: P through a histogram of its samples, Q through the KDE evaluated at
    the bin centers. Each is turned into a probability-mass vector (summing to
    one) and the discrete KL divergence between the two vectors is returned.

    Args:
        p_samples (np.ndarray): Samples from distribution P
        q_kde (scipy.stats.gaussian_kde): KDE of distribution Q
        n_bins (int): Number of bins for histogram approximation

    Returns:
        float: Estimated KL divergence
    """
    p_samples = np.asarray(p_samples).ravel()

    # The grid spans both the P samples and the data the KDE was fitted on
    min_val = min(p_samples.min(), q_kde.dataset.min())
    max_val = max(p_samples.max(), q_kde.dataset.max())
    bins = np.linspace(min_val, max_val, n_bins + 1)

    # Histogram of P and KDE of Q on the same grid
    p_hist, _ = np.histogram(p_samples, bins=bins, density=True)
    bin_centers = 0.5 * (bins[1:] + bins[:-1])
    q_val = q_kde(bin_centers)

    # Small epsilon avoids division by zero and log of zero in empty bins
    epsilon = 1e-10
    p_mass = p_hist + epsilon
    q_mass = q_val + epsilon

    # Normalize to probability masses over the bins
    p_mass = p_mass / np.sum(p_mass)
    q_mass = q_mass / np.sum(q_mass)

    # The vectors are already masses, so the plain sum IS the discrete KL.
    # Multiplying by the bin width as well would make the result scale with
    # the range of the data.
    return np.sum(p_mass * np.log(p_mass / q_mass))

def compute_mode_coverage(samples, mode_means, threshold=1.0):
    """
    Compute the fraction of modes that are adequately covered by samples.

    Each sample is assigned to its nearest mode and counted only when it lies
    strictly closer than ``threshold``. A mode is covered when its count
    exceeds 5% of the per-mode share ``len(samples) / len(mode_means)``.

    Args:
        samples (np.ndarray): Sample points to evaluate (flattened internally)
        mode_means (np.ndarray): Means of the target modes
        threshold (float): Distance threshold to consider a mode covered

    Returns:
        float: Mode coverage ratio in [0,1]
    """
    points = samples.flatten()
    n_modes = len(mode_means)
    centers = np.asarray(mode_means).reshape(1, -1)

    # Distance from every sample (rows) to every mode (columns)
    gaps = np.abs(points[:, np.newaxis] - centers)
    nearest = np.argmin(gaps, axis=1)
    is_close = gaps[np.arange(points.size), nearest] < threshold

    # Number of close-enough samples attributed to each mode
    hits = np.bincount(nearest[is_close], minlength=n_modes)

    # Modes are covered if they have at least 5% of expected samples
    min_hits = 0.05 * (len(points) / n_modes)
    return np.sum(hits > min_hits) / n_modes

def compute_ess(particles, target_distribution):
    """
    Compute normalized effective sample size of particles.

    Particles are weighted by their probability under the target (normalized
    to sum to one); the ESS is ``1 / sum(w**2)`` divided by the particle count.

    Args:
        particles (np.ndarray): Particle set
        target_distribution: Target distribution with log_prob method

    Returns:
        float: Normalized ESS in [0,1]
    """
    log_w = target_distribution.log_prob(particles)

    # Shift by the maximum before exponentiating to avoid underflow
    unnormalized = np.exp(log_w - np.max(log_w))
    w = unnormalized / np.sum(unnormalized)

    effective_count = 1.0 / np.sum(w ** 2)
    return effective_count / len(particles)

# Note: Correlation error isn't applicable for 1D distributions since there are no correlations 
# between dimensions. This metric will be used in higher dimensional test cases.

def evaluate_method(method_name, particles, target_samples, target_gmm, mode_means, target_kde,
                    runtime=None):
    """
    Evaluate a particle-based approximation using multiple metrics.
    
    Args:
        method_name (str): Name of the method
        particles (np.ndarray): Particles from the method
        target_samples (np.ndarray): Samples from target distribution
        target_gmm (GMMDistribution): Target GMM distribution
        mode_means (np.ndarray): Means of target modes
        target_kde (scipy.stats.gaussian_kde): KDE of target distribution
        runtime (float, optional): Runtime in seconds to report. When omitted,
            it is looked up by ``method_name`` in the module-level
            ``methods_runtime`` dict if that exists, and reported as 0 otherwise.
    
    Returns:
        dict: Dictionary of metric results
    """
    particles_flat = particles.flatten()
    target_flat = target_samples.flatten()
    method_kde = gaussian_kde(particles_flat)
    
    if runtime is None:
        # ``methods_runtime`` is only created once the experiment driver has run;
        # a guarded lookup keeps this function usable on its own instead of
        # raising NameError.
        runtime = globals().get("methods_runtime", {}).get(method_name, 0)
    
    # Compute all metrics
    results = {
        "Method": method_name,
        "MMD": compute_mmd(target_flat, particles_flat),
        "KL(Target||Method)": estimate_kl_divergence(target_flat, method_kde),
        "KL(Method||Target)": estimate_kl_divergence(particles_flat, target_kde),
        "Mode Coverage": compute_mode_coverage(particles, mode_means, threshold=1.0),
        "ESS": compute_ess(particles, target_gmm),
        "Wasserstein": wasserstein_distance(target_flat, particles_flat),
        "Runtime (s)": runtime
    }
    
    return results


# ================================
# Define the target distribution
# ================================

def create_target_distribution():
    """
    Build the 1D three-component GMM used as the evaluation target.

    Components sit at -3, 0 and 3 with variances 0.8, 0.5, 0.5 and mixture
    weights 0.3, 0.4, 0.3.

    Returns:
        GMMDistribution: Target 1D distribution
    """
    # One row per component: means are (3, 1), covariances are (3, 1, 1)
    component_means = np.array([-3.0, 0.0, 3.0]).reshape(3, 1)
    component_covs = np.array([0.8, 0.5, 0.5]).reshape(3, 1, 1)
    mixture_weights = np.array([0.3, 0.4, 0.3])

    return GMMDistribution(
        component_means,
        component_covs,
        mixture_weights,
        name="Target 1D GMM",
    )


# ================================
# Experiment setup
# ================================

def get_method(method_name, **kwargs):
    """
    Build the method object that corresponds to ``method_name``.

    The name is matched case-insensitively against "ESCORT", "SVGD", "DVRL"
    and "SIR". Options missing from ``kwargs`` fall back to the defaults
    written out below.

    Args:
        method_name (str): Name of the method to create
        **kwargs: Additional arguments for method configuration

    Returns:
        object: Method instance

    Raises:
        ValueError: If ``method_name`` is not one of the known methods.
    """
    key = method_name.upper()
    option = kwargs.get

    if key == "ESCORT":
        # Apart from step size, iteration count and verbosity, the ESCORT
        # configuration is fixed here with the optional features switched off
        return AdaptiveSVGD(
            step_size=option('step_size', 0.05),
            n_iter=option('n_iter', 300),
            lambda_reg=0.0,
            dynamic_bandwidth=False,
            enhanced_repulsion=False,
            noise_level=0.01,
            noise_decay=0.9,
            mode_balancing=False,
            adaptive_lambda=False,
            verbose=option('verbose', True)
        )

    if key == "SVGD":
        return SVGD(
            step_size=option('step_size', 0.05),
            n_iter=option('n_iter', 200),
            lambda_reg=option('lambda_reg', 0.0),
            dynamic_bandwidth=option('dynamic_bandwidth', False),
            enhanced_repulsion=option('enhanced_repulsion', False),
            noise_level=option('noise_level', 0.0),
            mode_balancing=option('mode_balancing', False),
            verbose=option('verbose', True)
        )

    if key == "DVRL":
        return DVRLAdapter(
            obs_dim=1,
            action_dim=1,
            h_dim=option('h_dim', 64),
            z_dim=option('z_dim', 32),
            n_particles=option('n_particles', 1000),
            learning_rate=option('learning_rate', 1e-4),
            n_iter=option('n_iter', 300)
        )

    if key == "SIR":
        return SIRAdapter(
            n_particles=option('n_particles', 1000),
            n_iter=option('n_iter', 300),
            resample_threshold=option('resample_threshold', 0.5),
            verbose=option('verbose', True)
        )

    raise ValueError(f"Unknown method: {method_name}")

def run_experiment_with_multiple_seeds(methods_to_run=None, n_runs=5, seeds=None, **kwargs):
    """
    Run the full 1D evaluation experiment with multiple random seeds.
    
    The target distribution and its evaluation samples are generated once under
    a fixed seed (42); only the particle initializations vary between runs.
    
    Args:
        methods_to_run (list): List of method names to run. If None, runs all methods.
        n_runs (int): Number of runs with different seeds. Ignored when ``seeds``
            is given; the number of seeds then defines the number of runs.
        seeds (list): List of seeds to use. If None, random seeds will be generated.
        **kwargs: Additional arguments for method configuration, forwarded to
            ``get_method``. ``n_particles``, when present, also sets the size of
            the initial particle set (default 1000).
        
    Returns:
        tuple: ``(mean_results_df, all_results_df, all_particles, target_gmm, mode_means)``:
            DataFrame of per-method mean / standard error, DataFrame with one row
            per (method, run), dict mapping method name to the list of particle
            arrays (one per run), the target distribution, and the mode means.
    """
    # Explicit seeds define how many runs there are, so messages stay accurate
    if seeds is not None:
        seeds = list(seeds)
        n_runs = len(seeds)
    
    print(f"Starting 1D GMM evaluation experiment with {n_runs} different initializations...")
    
    # Set default methods if not specified
    if methods_to_run is None:
        methods_to_run = ["ESCORT", "SVGD", "DVRL", "SIR"]
    
    # Generate random seeds if not provided
    if seeds is None:
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=n_runs)
    
    # One result dict per (method, run)
    all_results = []
    
    # Particles for each method, one entry per run
    all_particles = {method: [] for method in methods_to_run}
    
    # Latest runtime per method; module-level because evaluate_method reads it
    global methods_runtime
    methods_runtime = {}
    
    # Fixed seed for the target distribution - it must be the same across all runs
    # so that only the initializations vary
    np.random.seed(42)
    
    # Create target distribution (same for all runs)
    target_gmm = create_target_distribution()
    
    # Mode means of the target (same for all runs)
    mode_means = np.array([-3.0, 0.0, 3.0])
    
    # Generate target samples for evaluation (same for all runs)
    n_eval_samples = 2000
    target_samples = target_gmm.sample(n_eval_samples)
    target_kde = gaussian_kde(target_samples.flatten())
    
    # Define score function for particle updates (same for all runs)
    def score_fn(x, return_logp=False):
        """Score function (gradient of log density) for the target GMM."""
        scores = target_gmm.score(x)
        if return_logp:
            log_probs = target_gmm.log_prob(x)
            return scores, log_probs
        return scores
    
    # Keep the initial particle count in step with the count handed to the methods
    n_particles = kwargs.get('n_particles', 1000)
    
    # Run the experiment multiple times with different initializations
    for run_idx, seed in enumerate(seeds):
        print(f"\n=== Run {run_idx+1}/{n_runs} (Initialization Seed: {seed}) ===")
        
        # Set the random seed for this run's initialization only
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        
        # Different initialization strategy for each run (cycles through four)
        initial_particles, init_description = _make_initial_particles(run_idx, n_particles)
        print(f"  Using {init_description} initialization")
        
        for method_name in methods_to_run:
            print(f"  Running {method_name}...")
            
            # Add method-specific perturbation to make each method start slightly differently
            method_particles = initial_particles.copy()
            method_particles += np.random.randn(n_particles, 1) * 0.1
            
            # Create method instance
            method = get_method(method_name, **kwargs)
            
            # Time the execution
            start_time = time.time()
            particles = method.fit_transform(
                method_particles, 
                score_fn,
                target_samples=target_samples,
                return_convergence=False
            )
            methods_runtime[method_name] = time.time() - start_time
            
            # Store particles for this method and run
            all_particles[method_name].append(particles)
            
            # Evaluate and store results
            method_results = evaluate_method(
                method_name, particles, target_samples, 
                target_gmm, mode_means, target_kde
            )
            
            # Add run information
            method_results["Run"] = run_idx + 1
            method_results["Seed"] = seed
            method_results["Initialization"] = init_description
            
            all_results.append(method_results)
    
    # Convert all results to a DataFrame
    all_results_df = pd.DataFrame(all_results)
    
    # Calculate mean and standard error for each metric and method
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)', 
              'Mode Coverage', 'ESS', 'Wasserstein', 'Runtime (s)']
    mean_results_df = _summarize_metrics(all_results_df, methods_to_run, metrics)
    
    # Print results
    print("\nResults Summary (Mean ± Standard Error):")
    print(mean_results_df[list(metrics)])
    
    return mean_results_df, all_results_df, all_particles, target_gmm, mode_means


def _make_initial_particles(run_idx, n_particles):
    """Draw one run's initial particles; returns ``(particles, description)`` with the strategy chosen by ``run_idx % 4``."""
    initialization_type = run_idx % 4
    
    if initialization_type == 0:
        # Standard Gaussian initialization
        return np.random.randn(n_particles, 1) * 2.0, "Standard Gaussian"
    
    if initialization_type == 1:
        # Uniform initialization
        return np.random.uniform(-5, 5, (n_particles, 1)), "Uniform [-5, 5]"
    
    if initialization_type == 2:
        # Concentrated initialization around one of the mode centers
        center = np.random.choice([-3, 0, 3])
        particles = np.random.randn(n_particles, 1) * 0.5 + center
        return particles, f"Concentrated around {center}"
    
    # Multimodal initialization with centers that differ from the target modes
    modes = np.array([-4, -1, 2, 5])
    mode_idx = np.random.choice(len(modes), n_particles)
    particles = modes[mode_idx].reshape(-1, 1) + np.random.randn(n_particles, 1) * 0.5
    return particles, "Custom multimodal"


def _summarize_metrics(all_results_df, methods, metrics):
    """Build the per-method table of mean, standard error and a formatted "mean ± se" string for every metric."""
    mean_results = []
    
    for method in methods:
        method_data = all_results_df[all_results_df["Method"] == method]
        method_stats = {"Method": method}
        
        for metric in metrics:
            values = method_data[metric].values
            mean_val = np.mean(values)
            se_val = sem(values)  # Standard error of the mean
            
            method_stats[f"{metric}_mean"] = mean_val
            method_stats[f"{metric}_se"] = se_val
            method_stats[metric] = f"{mean_val:.6f} ± {se_val:.6f}"
        
        mean_results.append(method_stats)
    
    return pd.DataFrame(mean_results).set_index('Method')


def _finite_or_zero(values):
    """Return *values* as a list of floats with non-finite entries replaced by 0.0."""
    return [float(v) if np.isfinite(v) else 0.0 for v in values]


def _safe_sem(values):
    """Return the standard error of the mean, or 0.0 when it is undefined (< 2 samples)."""
    values = np.asarray(values, dtype=float)
    if values.size < 2:
        return 0.0
    se = sem(values)
    return float(se) if np.isfinite(se) else 0.0


def _plot_kde_curve(x, samples, **plot_kwargs):
    """Plot a Gaussian KDE of *samples* over grid *x*; skip the curve if the KDE cannot be built."""
    try:
        kde = gaussian_kde(samples)
    except (np.linalg.LinAlgError, ValueError) as exc:
        # Degenerate particle sets (e.g. all particles identical, or a single
        # particle) have no usable covariance; skip the curve instead of
        # aborting every remaining figure.
        print(f"Warning: skipping KDE curve ({exc})")
        return
    plt.plot(x, kde(x), **plot_kwargs)


def _plot_distribution_comparison(viz, methods, all_particles, mean_results_df, target_gmm, colors):
    """Draw and save Figure 1: target density plus per-method histogram/KDE panels."""
    n_methods = len(methods)
    plt.figure(figsize=(15, 12))

    # Top panel: the target distribution itself
    plt.subplot(n_methods + 1, 1, 1)
    viz.visualize_1d(target_gmm, title="Target 1D GMM Distribution",
                     show_components=True, n_samples=500)

    # Evaluation grid and reference density shared by all method panels
    x = np.linspace(-7, 7, 1000)
    true_density = np.exp(target_gmm.log_prob(x.reshape(-1, 1)))

    for i, method_name in enumerate(methods):
        plt.subplot(n_methods + 1, 1, i + 2)
        color = colors[i % len(colors)]

        for run_idx, particles in enumerate(all_particles[method_name]):
            particles_flat = particles.flatten()
            if run_idx == 0:
                # First run: histogram plus a labelled, opaque KDE curve
                plt.hist(particles_flat, bins=50, density=True, alpha=0.3, color=color,
                         label=f'{method_name} Particles (Run 1)')
                _plot_kde_curve(x, particles_flat, color=color, linewidth=2,
                                label=f'{method_name} KDE (Run 1)')
            else:
                # Remaining runs: faint KDE curves only, to show run-to-run spread
                _plot_kde_curve(x, particles_flat, color=color, linewidth=1, alpha=0.3)

        # Reference density for comparison
        plt.plot(x, true_density, 'k--', linewidth=1.5, label='True Density')

        # These columns hold the pre-formatted "mean ± SE" strings
        mmd = mean_results_df.loc[method_name, 'MMD']
        mode_coverage = mean_results_df.loc[method_name, 'Mode Coverage']

        plt.title(f"{method_name} Approximation\n{mmd} (MMD), {mode_coverage} (Mode Coverage)")
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(SCRIPT_DIR, "escort_1d_comparison_with_errors.png"), dpi=300)


def _plot_metric_bars(methods, mean_results_df, metrics, colors):
    """Draw and save Figure 2: per-metric bar charts of means with standard-error bars."""
    plt.figure(figsize=(18, 12))
    bar_colors = [colors[k % len(colors)] for k in range(len(methods))]

    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i + 1)

        means = [mean_results_df.loc[method, f"{metric}_mean"] for method in methods]
        errors = [mean_results_df.loc[method, f"{metric}_se"] for method in methods]
        # The SE is NaN when only one run exists; use 0 for geometry so that
        # error bars and annotation positions stay finite.
        plot_errors = _finite_or_zero(errors)

        bars = plt.bar(methods, means, color=bar_colors,
                       yerr=plot_errors, capsize=10, alpha=0.7)

        # Annotate each bar just above its error bar; the label keeps the raw SE
        for j, bar in enumerate(bars):
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height + plot_errors[j] + 0.01,
                     f'{means[j]:.4f} ± {errors[j]:.4f}', ha='center', va='bottom', rotation=45,
                     fontsize=9)

        plt.title(metric)
        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)

        # For Mode Coverage and ESS, higher is better
        # NOTE(review): the fixed 0-1.1 range presumes both are normalised to [0, 1] - confirm for ESS.
        if metric in ['Mode Coverage', 'ESS']:
            plt.ylim(0, 1.1)

    plt.tight_layout()
    plt.savefig(os.path.join(SCRIPT_DIR, "escort_1d_metrics_with_errors.png"), dpi=300)


def _plot_metric_boxplots(methods, all_results_df, metrics):
    """Draw and save Figure 3: per-metric box plots of the raw per-run results."""
    plt.figure(figsize=(18, 12))

    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i + 1)

        box_data = [all_results_df[all_results_df['Method'] == method][metric].values
                    for method in methods]

        plt.boxplot(box_data, patch_artist=True,
                    boxprops=dict(facecolor='lightblue', color='blue'),
                    whiskerprops=dict(color='blue'),
                    capprops=dict(color='blue'),
                    medianprops=dict(color='red'))

        plt.title(f"{metric} Distribution Across Different Initializations")
        # Boxes sit at positions 1..N by default. Labelling them through xticks
        # avoids the boxplot `labels` keyword, which newer Matplotlib deprecates.
        plt.xticks(range(1, len(methods) + 1), methods, rotation=45)
        plt.grid(axis='y', alpha=0.3)

        # For Mode Coverage and ESS, higher is better
        if metric in ['Mode Coverage', 'ESS']:
            plt.ylim(0, 1.1)

    plt.tight_layout()
    plt.savefig(os.path.join(SCRIPT_DIR, "escort_1d_boxplots.png"), dpi=300)


def _plot_by_initialization(methods, all_results_df, colors):
    """Draw and save Figure 4: grouped bars per initialization type (skipped if only one type)."""
    initializations = sorted(all_results_df['Initialization'].unique())
    if len(initializations) <= 1:
        return

    plt.figure(figsize=(18, 15))
    bar_width = 0.15
    group_positions = np.arange(len(initializations))

    for i, metric in enumerate(['MMD', 'Mode Coverage', 'Wasserstein']):
        plt.subplot(3, 1, i + 1)

        for j, method in enumerate(methods):
            values = []
            errors = []
            for init in initializations:
                df_subset = all_results_df[(all_results_df['Method'] == method) &
                                           (all_results_df['Initialization'] == init)]
                if df_subset.empty:
                    # No run of this method used this initialization
                    values.append(0)
                    errors.append(0)
                else:
                    values.append(df_subset[metric].mean())
                    # A (method, initialization) cell often holds a single run,
                    # for which the SE is undefined; plot it without an error bar.
                    errors.append(_safe_sem(df_subset[metric]))

            positions = [pos + bar_width * j for pos in group_positions]
            plt.bar(positions, values, bar_width,
                    label=method,
                    color=colors[j % len(colors)],
                    yerr=errors, capsize=5)

        plt.xlabel('Initialization Type')
        plt.ylabel(metric)
        plt.title(f'Performance by Initialization Type - {metric}')
        # Centre each tick under its group of bars
        plt.xticks([pos + bar_width * (len(methods) - 1) / 2 for pos in range(len(initializations))],
                   initializations, rotation=45)
        plt.legend()
        plt.grid(axis='y', alpha=0.3)

        # For Mode Coverage, higher is better
        if metric == 'Mode Coverage':
            plt.ylim(0, 1.1)

    plt.tight_layout()
    plt.savefig(os.path.join(SCRIPT_DIR, "escort_1d_by_initialization.png"), dpi=300)


def visualize_results_with_error_bars(mean_results_df, all_results_df, all_particles, target_gmm, mode_means):
    """
    Create and save comparison figures for all evaluated methods, with error bars.

    Up to four PNG files are written to SCRIPT_DIR (distribution comparison,
    metric bar charts, metric box plots and, when more than one initialization
    type is present, a per-initialization breakdown); the figures are then shown
    with ``plt.show()``.

    Standard errors that are undefined (fewer than two runs) are drawn as
    zero-length error bars, and KDE curves that cannot be estimated for a
    degenerate particle set are skipped with a printed warning.

    Args:
        mean_results_df: DataFrame indexed by method, with ``<metric>_mean`` and
            ``<metric>_se`` columns plus a formatted ``<metric>`` string column
        all_results_df: DataFrame with one row per (method, run); must contain
            the ``Method`` and ``Initialization`` columns and the metric columns
        all_particles: Dictionary mapping method name to a list of particle
            arrays, one per run
        target_gmm: Target GMM distribution (its ``sample`` and ``log_prob`` are used)
        mode_means: Means of target modes (accepted for interface compatibility;
            not used by the plots)
    """
    viz = GMMVisualizer(cmap='viridis', figsize=(12, 8))

    methods = list(mean_results_df.index)

    # Fixed seed so the target panel looks the same on every invocation.
    # Note that this reseeds NumPy's global RNG.
    n_viz_samples = 2000
    np.random.seed(42)
    # The drawn samples are not used by any plot. The call is kept so that
    # whatever random state visualize_1d sees afterwards is unchanged.
    # NOTE(review): presumably sample() draws from NumPy's global RNG - confirm.
    _ = target_gmm.sample(n_viz_samples)

    # One colour per method, cycled if there are more methods than colours
    colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown']

    # Metrics shown in the 2x3 grids of Figures 2 and 3
    metrics = ['MMD', 'KL(Target||Method)', 'KL(Method||Target)',
               'Mode Coverage', 'ESS', 'Wasserstein']

    _plot_distribution_comparison(viz, methods, all_particles, mean_results_df, target_gmm, colors)
    _plot_metric_bars(methods, mean_results_df, metrics, colors)
    _plot_metric_boxplots(methods, all_results_df, metrics)
    _plot_by_initialization(methods, all_results_df, colors)

    # Show the plots
    plt.show()


# ================================
# Run the experiment
# ================================

if __name__ == "__main__":
    # Command-line entry point: parse options, run every requested method over
    # several seeds, save CSV/PNG outputs next to this script, and print a
    # per-initialization summary.

    # Set up argument parser
    parser = argparse.ArgumentParser(description='ESCORT Framework Evaluation on 1D Multi-Modal GMM with Multiple Seeds')
    parser.add_argument('--methods', nargs='+', default=['ESCORT', 'SVGD', 'DVRL', 'SIR'],
                        help='Methods to evaluate (default: ESCORT SVGD DVRL SIR)')
    parser.add_argument('--n_runs', type=int, default=5,
                        help='Number of runs with different initializations (default: 5)')
    parser.add_argument('--n_iter', type=int, default=300,
                        help='Number of iterations (default: 300)')
    parser.add_argument('--step_size', type=float, default=0.05,
                        help='Step size for updates (default: 0.05)')
    parser.add_argument('--no_verbose', action='store_false', dest='verbose',
                        help='Disable verbose output (default: verbose enabled)')
    parser.add_argument('--lambda_reg', type=float, default=0.01,
                        help='GSWD regularization weight (default: 0.01)')
    parser.add_argument('--noise_level', type=float, default=0.05,
                        help='Noise level for exploration (default: 0.05)')
    parser.add_argument('--enhanced_repulsion', action='store_true',
                        help='Enable enhanced repulsion forces (default: False)')
    parser.add_argument('--noise_decay', type=float, default=0.95,
                        help='Noise decay rate (default: 0.95)')
    # Add DVRL-specific arguments
    parser.add_argument('--h_dim', type=int, default=64,
                        help='Dimension of RNN hidden state (for DVRL, default: 64)')
    parser.add_argument('--z_dim', type=int, default=32,
                        help='Dimension of stochastic latent state (for DVRL, default: 32)')
    # Add SIR-specific arguments
    parser.add_argument('--resample_threshold', type=float, default=0.5,
                        help='Threshold for resampling in SIR (default: 0.5)')
    parser.add_argument('--n_particles', type=int, default=1000,
                        help='Number of particles to use for particle methods (default: 1000)')
    parser.add_argument('--fixed_seeds', action='store_true',
                        help='Use fixed seeds instead of random ones (default: False)')
    parser.add_argument('--initialization_types', nargs='+',
                        choices=['gaussian', 'uniform', 'concentrated', 'multimodal', 'random'],
                        default=['random'],
                        help='Types of initialization to use (default: random)')

    # Parse arguments
    args = parser.parse_args()

    # Zero or negative runs would yield empty result tables and fail later
    # with a far less helpful error.
    if args.n_runs < 1:
        parser.error("--n_runs must be at least 1")

    # NOTE(review): args.initialization_types is parsed but never forwarded to
    # the experiment in this block - confirm whether
    # run_experiment_with_multiple_seeds should receive it.

    # Configure parameters forwarded as keyword arguments to the experiment runner
    method_params = {
        'n_iter': args.n_iter,
        'step_size': args.step_size,
        'verbose': args.verbose,
        'lambda_reg': args.lambda_reg,
        'noise_level': args.noise_level,
        'enhanced_repulsion': args.enhanced_repulsion,
        'noise_decay': args.noise_decay,
        'h_dim': args.h_dim,
        'z_dim': args.z_dim,
        # Add SIR parameters
        'n_particles': args.n_particles,
        'resample_threshold': args.resample_threshold
    }

    # Use fixed seeds if requested
    if args.fixed_seeds:
        seeds = [42, 123, 456, 789, 101]  # Fixed seeds for reproducibility
        if args.n_runs > len(seeds):
            # More runs than hand-picked seeds: extend the list from a
            # fixed-seed generator so there is one reproducible seed per run.
            extra_rng = np.random.RandomState(0)
            n_extra = args.n_runs - len(seeds)
            seeds.extend(int(s) for s in extra_rng.randint(0, 10000, size=n_extra))
        seeds = seeds[:args.n_runs]  # Truncate if fewer runs requested
    else:
        # Generate random master seed
        master_seed = np.random.randint(0, 10000)
        print(f"Master seed: {master_seed}")

        # Use master seed to generate seeds for individual runs; convert to a
        # plain list of ints so both branches hand over the same type.
        np.random.seed(master_seed)
        seeds = np.random.randint(0, 10000, size=args.n_runs).tolist()

    # Print seeds being used
    print(f"Using seeds: {seeds}")

    # Run the experiment with specified methods and seeds
    mean_results_df, all_results_df, all_particles, target_gmm, mode_means = run_experiment_with_multiple_seeds(
        methods_to_run=args.methods,
        n_runs=args.n_runs,
        seeds=seeds,
        **method_params
    )

    # Save results to CSV in the script directory
    mean_results_df.to_csv(os.path.join(SCRIPT_DIR, "escort_1d_mean_results.csv"))
    all_results_df.to_csv(os.path.join(SCRIPT_DIR, "escort_1d_all_results.csv"))

    # Visualize the results with error bars
    visualize_results_with_error_bars(mean_results_df, all_results_df, all_particles, target_gmm, mode_means)

    print("\nExperiment complete. Results saved to CSV and visualizations saved as PNG files.")
    print(f"Files saved in: {SCRIPT_DIR}")

    # Report the initialization types that were most challenging for each method
    print("\nPerformance by Initialization Type:")
    for method in args.methods:
        method_data = all_results_df[all_results_df['Method'] == method]

        # For each initialization type, get the mean MMD
        init_performance = {
            init: method_data[method_data['Initialization'] == init]['MMD'].mean()
            for init in method_data['Initialization'].unique()
        }

        print(f"\n{method}:")
        if not init_performance:
            # No rows for this method (e.g. it produced no results); avoid
            # indexing into an empty ranking.
            print("  No results recorded for this method.")
            continue

        # Sort by performance (lower MMD is better)
        sorted_inits = sorted(init_performance.items(), key=lambda item: item[1])

        print(f"  Best performance on: {sorted_inits[0][0]} (MMD: {sorted_inits[0][1]:.6f})")
        print(f"  Worst performance on: {sorted_inits[-1][0]} (MMD: {sorted_inits[-1][1]:.6f})")
