import os
import sys
from pathlib import Path
import numpy as np
from scipy.stats import truncnorm, ks_2samp
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to avoid threading issues
import matplotlib.pyplot as plt
import time
from joblib import Parallel, delayed

# Put the parent of this file's directory (taken to be the project root) at the
# front of sys.path so that project-level modules can be imported below.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Make the project root the working directory; relative paths used later in
# this file (e.g. the plots directory) are therefore resolved against it.
os.chdir(project_root)

# Must come after the sys.path change above.
# NOTE(review): presumably `bounds` lives in the project root - confirm.
from bounds import cvar_estimator, stochastic_cvar_bound_const_eps

# Configuration parameters - modify these to change simulation behavior
ALPHA = 0.2  # CVaR alpha parameter - controls the tail probability for CVaR calculation
DELTA = 0.05  # Confidence parameter for bounds calculation
N_SAMPLES = 10000  # Number of samples to generate
N_RUNS = 100  # Number of runs for timing measurements
N_JOBS = -1  # Number of jobs for parallel execution (-1 = use all available CPU cores)

# Distribution parameters - modify these to change the underlying distributions
# The three GMM lists are parallel: entry i of each describes component i.
GMM_MEANS = [0.2, -0.2, -0.5, 0.5, 0]  # Means of the 5 GMM components
# Variances, not standard deviations: the samplers in this file take the square root.
GMM_VARIANCES = [0.5, 0.2, 0.1, 0.1, 0.3]  # Variances of the 5 GMM components
# These weights sum to 2.0; the helpers in this file divide by the sum before use.
GMM_WEIGHTS = [0.6, 0.4, 0.1, 0.1, 0.8]  # Weights of the 5 GMM components (will be normalized)
TRUNCATION_LOWER = -1  # Lower bound for truncation
TRUNCATION_UPPER = 1   # Upper bound for truncation

# Seed NumPy's global RNG, which the np.random.* calls in this file draw from.
# NOTE(review): work dispatched to joblib worker processes may not share this
# seeded state - confirm if reproducibility of parallel runs matters.
np.random.seed(42)


def sample_truncated_gmm_1d(means, variances, weights, n_samples, lower, upper):
    """
    Draw samples from a 1D Gaussian mixture whose components are each truncated
    to [lower, upper].

    A component index is drawn for every sample first; then, component by
    component, all samples assigned to it are drawn in one truncnorm call.

    Args:
        means: Component means.
        variances: Component variances (the square root is used as the scale).
        weights: Component weights, normalized here to sum to 1; None means
            equal weights.
        n_samples: Number of samples to draw.
        lower: Lower truncation bound, or None for no lower bound.
        upper: Upper truncation bound, or None for no upper bound.

    Returns:
        Tuple (samples, component_labels), both of length n_samples.
    """
    means = np.array(means)
    variances = np.array(variances)
    n_components = len(means)

    if weights is not None:
        mixing = np.array(weights)
        mixing = mixing / mixing.sum()
    else:
        mixing = np.full(n_components, 1.0 / n_components)

    # Assign every sample to a mixture component.
    component_labels = np.random.choice(n_components, size=n_samples, p=mixing)
    samples = np.zeros(n_samples)

    for k in range(n_components):
        belongs_to_k = component_labels == k
        if not belongs_to_k.any():
            continue

        mu = means[k]
        sigma = np.sqrt(variances[k])
        # truncnorm takes its bounds in standardized units: (bound - loc) / scale.
        z_low = -np.inf if lower is None else (lower - mu) / sigma
        z_high = np.inf if upper is None else (upper - mu) / sigma

        n_draws = np.sum(belongs_to_k)
        samples[belongs_to_k] = truncnorm.rvs(z_low, z_high, loc=mu, scale=sigma, size=n_draws)

    return samples, component_labels

def sample_non_truncated_gmm_1d(means, variances, weights, n_samples):
    """
    Draw samples from an ordinary (non-truncated) 1D Gaussian mixture.

    A component index is drawn for every sample first; then, component by
    component, all samples assigned to it are drawn in one np.random.normal call.

    Args:
        means: Component means.
        variances: Component variances (the square root is used as the scale).
        weights: Component weights, normalized here to sum to 1; None means
            equal weights.
        n_samples: Number of samples to draw.

    Returns:
        Tuple (samples, component_labels), both of length n_samples.
    """
    means = np.array(means)
    variances = np.array(variances)
    n_components = len(means)

    if weights is not None:
        mixing = np.array(weights)
        mixing = mixing / mixing.sum()
    else:
        mixing = np.full(n_components, 1.0 / n_components)

    # Assign every sample to a mixture component.
    component_labels = np.random.choice(n_components, size=n_samples, p=mixing)
    samples = np.zeros(n_samples)

    for k in range(n_components):
        belongs_to_k = component_labels == k
        if not belongs_to_k.any():
            continue
        mu = means[k]
        sigma = np.sqrt(variances[k])
        samples[belongs_to_k] = np.random.normal(mu, sigma, size=np.sum(belongs_to_k))

    return samples, component_labels

def create_truncated_normal_approximation(means, variances, weights, lower, upper):
    """
    Moment-match a single Normal distribution to a 1D GMM.

    The returned mean and standard deviation are those of the *untruncated*
    mixture:

        mean = sum_i w_i * mu_i
        var  = sum_i w_i * (sigma_i^2 + mu_i^2) - mean^2

    Truncation effects are deliberately ignored, so the result is only an
    approximation of the truncated mixture's moments.

    Args:
        means: Component means.
        variances: Component variances.
        weights: Component weights, normalized here to sum to 1; None means
            equal weights (same convention as the GMM samplers in this file).
        lower: Unused; kept so the signature mirrors the truncated samplers.
        upper: Unused; kept so the signature mirrors the truncated samplers.

    Returns:
        Tuple (approx_mean, approx_std).
    """
    means = np.array(means)
    variances = np.array(variances)
    if weights is None:
        # Previously None fell through to np.array(None) / np.sum(None) and
        # raised a TypeError; treat it as a uniform mixture instead.
        n_components = len(means)
        weights = np.ones(n_components) / n_components
    else:
        weights = np.array(weights)
        weights = weights / np.sum(weights)

    # First and second raw moments of the mixture give its mean and variance.
    approx_mean = np.sum(weights * means)
    approx_var = np.sum(weights * (variances + means**2)) - approx_mean**2

    return approx_mean, np.sqrt(approx_var)

def sample_truncated_normal_approximation(mean, std, n_samples, lower, upper):
    """
    Draw n_samples values from a Normal(mean, std) truncated to [lower, upper].

    A bound given as None leaves that side unbounded.
    """
    # truncnorm takes its bounds in standardized units: (bound - loc) / scale.
    z_low = -np.inf if lower is None else (lower - mean) / std
    z_high = np.inf if upper is None else (upper - mean) / std
    return truncnorm.rvs(z_low, z_high, loc=mean, scale=std, size=n_samples)

def sample_non_truncated_normal_approximation(mean, std, n_samples):
    """
    Draw n_samples values from an ordinary (non-truncated) Normal(mean, std)
    using NumPy's global random state.
    """
    draws = np.random.normal(loc=mean, scale=std, size=n_samples)
    return draws

def measure_sampling_time(sampling_func, *args, n_runs, **kwargs):
    """
    Time repeated calls of a sampling function.

    Args:
        sampling_func: Callable to time; invoked as sampling_func(*args, **kwargs).
        *args: Positional arguments forwarded to sampling_func.
        n_runs: Number of timed calls (keyword-only).
        **kwargs: Keyword arguments forwarded to sampling_func.

    Returns:
        Tuple (mean_seconds, std_seconds, samples). `samples` comes from the
        first call that produced a result; if that result is a tuple, such as
        (samples, labels), only its first element is kept. With n_runs == 0
        the mean and std are NaN and samples is None.
    """
    times = []
    samples = None
    for _ in range(n_runs):
        # perf_counter is monotonic and has the highest available resolution,
        # which makes it the right clock for short intervals; time.time() can
        # be coarse and may jump when the system clock is adjusted.
        started = time.perf_counter()
        result = sampling_func(*args, **kwargs)
        times.append(time.perf_counter() - started)

        # Keep one representative sample set for the caller.
        if samples is None:
            samples = result[0] if isinstance(result, tuple) else result

    return np.mean(times), np.std(times), samples

def measure_sampling_time_parallel(sampling_func, *args, n_runs, n_jobs, **kwargs):
    """
    Time repeated calls of a sampling function, dispatching the calls via joblib.

    Each call is timed inside the task itself, so the reported durations cover
    only the sampling call and not joblib's dispatch overhead.

    Args:
        sampling_func: Callable to time; invoked as sampling_func(*args, **kwargs).
        *args: Positional arguments forwarded to sampling_func.
        n_runs: Number of timed calls (keyword-only).
        n_jobs: Forwarded to joblib.Parallel (keyword-only).
        **kwargs: Keyword arguments forwarded to sampling_func.

    Returns:
        Tuple (mean_seconds, std_seconds, samples). `samples` comes from the
        first task's result; if that result is a tuple, such as
        (samples, labels), only its first element is kept. With n_runs == 0
        the mean and std are NaN and samples is None.
    """
    def single_run():
        # perf_counter is monotonic and high-resolution; time.time() can be
        # coarse and may jump when the system clock is adjusted.
        started = time.perf_counter()
        result = sampling_func(*args, **kwargs)
        return time.perf_counter() - started, result

    # Run sampling operations in parallel
    results = Parallel(n_jobs=n_jobs)(delayed(single_run)() for _ in range(n_runs))

    times = [elapsed for elapsed, _ in results]
    samples = None

    # Keep the first task's output as the representative sample set.
    if results:
        first_result = results[0][1]
        samples = first_result[0] if isinstance(first_result, tuple) else first_result

    return np.mean(times), np.std(times), samples

def compute_ks_distance(samples1, samples2):
    """
    Run scipy's two-sample Kolmogorov-Smirnov test on the given samples.

    Returns:
        Tuple (statistic, p_value) as reported by ks_2samp.
    """
    ks_result = ks_2samp(samples1, samples2)
    return ks_result.statistic, ks_result.pvalue

def compute_multiple_ks_distances_parallel(sample_pairs, n_jobs):
    """
    Evaluate the KS distance for several sample pairs, one joblib task per pair.

    Args:
        sample_pairs: List of tuples (samples1, samples2, name1, name2)
        n_jobs: Number of jobs for parallel execution

    Returns:
        List of tuples (statistic, p_value, name1, name2), in the order of
        sample_pairs.
    """
    def ks_for_pair(pair):
        first, second, first_name, second_name = pair
        stat, pval = compute_ks_distance(first, second)
        # The names are passed through so each result is self-describing.
        return stat, pval, first_name, second_name

    tasks = (delayed(ks_for_pair)(pair=entry) for entry in sample_pairs)
    return Parallel(n_jobs=n_jobs)(tasks)

def plot_sample_histograms(gmm_samples, normal_samples, bins, alpha):
    """
    Overlay density histograms of the two sample sets and save the figure.

    Args:
        gmm_samples: Samples drawn as 'Truncated GMM' (blue).
        normal_samples: Samples drawn as 'Truncated Normal Approximation' (red).
        bins: Passed to plt.hist for both histograms.
        alpha: Bar opacity passed to plt.hist for both histograms.

    The figure is written to the path returned by
    get_plot_filename("sample_histograms") and then closed; nothing is returned.
    """
    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}
    series = (
        (gmm_samples, 'Truncated GMM', 'blue'),
        (normal_samples, 'Truncated Normal Approximation', 'red'),
    )

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts

    for data, legend_label, colour in series:
        plt.hist(data, bins=bins, alpha=alpha, label=legend_label,
                 color=colour, density=True, edgecolor='black')

    plt.xlabel('Value', **label_font)
    plt.ylabel('Density', **label_font)
    plt.title('Comparison of Truncated GMM vs Truncated Normal Approximation',
              fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("sample_histograms")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved histogram plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

def plot_sample_cdfs(gmm_samples, normal_samples):
    """
    Plot the empirical CDFs of the two sample sets on one figure and save it.

    Each empirical CDF is the sorted samples plotted against i / n for
    i = 1..n. The figure is written to the path returned by
    get_plot_filename("sample_cdfs") and then closed; nothing is returned.
    """
    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts

    # Sorted values are the x-coordinates of the empirical CDF steps.
    gmm_x = np.sort(gmm_samples)
    normal_x = np.sort(normal_samples)

    # The i-th smallest value has cumulative probability i / n.
    gmm_count = len(gmm_x)
    normal_count = len(normal_x)
    gmm_y = np.arange(1, gmm_count + 1) / gmm_count
    normal_y = np.arange(1, normal_count + 1) / normal_count

    plt.plot(gmm_x, gmm_y, 'b-', label='Truncated GMM CDF', linewidth=4)
    plt.plot(normal_x, normal_y, 'r-', label='Truncated Normal CDF', linewidth=4)

    plt.xlabel('Value', **label_font)
    plt.ylabel('Cumulative Probability', **label_font)
    plt.title('Cumulative Distribution Functions Comparison',
              fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)

    # Span exactly the combined range of both sample sets.
    x_left = min(gmm_x[0], normal_x[0])
    x_right = max(gmm_x[-1], normal_x[-1])
    plt.xlim(x_left, x_right)
    plt.ylim(0, 1)

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("sample_cdfs")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved CDF plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

def plot_cvar_convergence_with_bounds(gmm_samples, normal_samples, alpha, delta, eps, sample_sizes):
    """
    Plot CVaR estimates against sample size for both sample sets, together
    with the bound band computed from the Normal samples, and save the figure.

    Args:
        gmm_samples: Samples plotted as 'Truncated GMM CVaR'.
        normal_samples: Samples plotted as 'Truncated Normal CVaR'; the bounds
            are computed from these.
        alpha: Passed as `alpha` to cvar_estimator and the bound function.
        delta: Passed as `delta` to the bound function.
        eps: Passed as `eps` to the bound function.
        sample_sizes: Prefix lengths to evaluate. None selects sizes from 100
            up to the smaller sample count in steps of that count // 20
            (at least 1).

    Returns:
        Tuple of lists (gmm_cvars, normal_cvars, normal_lower_bounds,
        normal_upper_bounds), one entry per sample size.
    """
    if sample_sizes is None:
        # Roughly 20 evenly spaced sizes over the available samples.
        n_available = min(len(gmm_samples), len(normal_samples))
        stride = max(1, n_available // 20)
        sample_sizes = np.arange(100, n_available + 1, stride)

    # The observed range of the Normal samples is used as the support bounds,
    # and the same range is assumed for x as for y.
    y_inf = np.min(normal_samples)
    y_sup = np.max(normal_samples)
    x_inf, x_sup = y_inf, y_sup

    rows = []
    for n in sample_sizes:
        # Evaluate on the first n samples of each set.
        gmm_value = cvar_estimator(gmm_samples[:n], alpha=alpha)
        normal_value = cvar_estimator(normal_samples[:n], alpha=alpha)
        lower, upper = stochastic_cvar_bound_const_eps(
            y_samp=normal_samples[:n],
            y_sup=y_sup,
            y_inf=y_inf,
            x_sup=x_sup,
            x_inf=x_inf,
            eps=eps,
            delta=delta,
            alpha=alpha
        )
        rows.append((gmm_value, normal_value, lower, upper))

    gmm_cvars = [row[0] for row in rows]
    normal_cvars = [row[1] for row in rows]
    normal_lower_bounds = [row[2] for row in rows]
    normal_upper_bounds = [row[3] for row in rows]

    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}
    estimate_style = {'linewidth': 4, 'markersize': 10}
    edge_style = {'alpha': 0.7, 'linewidth': 1}

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts

    # CVaR estimates
    plt.plot(sample_sizes, gmm_cvars, 'b-o', label='Truncated GMM CVaR', **estimate_style)
    plt.plot(sample_sizes, normal_cvars, 'r-s', label='Truncated Normal CVaR', **estimate_style)

    # Shaded band between the bounds, outlined with dashed edges
    plt.fill_between(sample_sizes, normal_lower_bounds, normal_upper_bounds,
                     alpha=0.3, color='red', label=f'Normal CVaR Bounds (δ={delta})')
    plt.plot(sample_sizes, normal_lower_bounds, 'r--', **edge_style)
    plt.plot(sample_sizes, normal_upper_bounds, 'r--', **edge_style)

    plt.xlabel('Number of Samples', **label_font)
    plt.ylabel(f'CVaR (α={alpha})', **label_font)
    plt.title(f'CVaR Bounds Convergence',
              fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)
    plt.ylim(bottom=0)  # y-axis starts at 0

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("cvar_convergence_with_bounds")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved CVaR convergence with bounds plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

    return gmm_cvars, normal_cvars, normal_lower_bounds, normal_upper_bounds

def plot_cvar_convergence_with_bounds_parallel(gmm_samples, normal_samples, alpha, delta, eps, sample_sizes, n_jobs):
    """
    Parallel variant of the CVaR-with-bounds convergence plot: the per-size
    CVaR estimates and bounds are computed as joblib tasks, then plotted and
    saved.

    Args:
        gmm_samples: Samples plotted as 'Truncated GMM CVaR'.
        normal_samples: Samples plotted as 'Truncated Normal CVaR'; the bounds
            are computed from these.
        alpha: Passed as `alpha` to cvar_estimator and the bound function.
        delta: Passed as `delta` to the bound function.
        eps: Passed as `eps` to the bound function.
        sample_sizes: Prefix lengths to evaluate. None selects sizes from 100
            up to the smaller sample count in steps of that count // 20
            (at least 1).
        n_jobs: Forwarded to joblib.Parallel.

    Returns:
        Tuple of lists (gmm_cvars, normal_cvars, normal_lower_bounds,
        normal_upper_bounds), one entry per sample size.
    """
    if sample_sizes is None:
        # Roughly 20 evenly spaced sizes over the available samples.
        n_available = min(len(gmm_samples), len(normal_samples))
        stride = max(1, n_available // 20)
        sample_sizes = np.arange(100, n_available + 1, stride)

    # The observed range of the Normal samples is used as the support bounds,
    # and the same range is assumed for x as for y.
    y_inf = np.min(normal_samples)
    y_sup = np.max(normal_samples)
    x_inf, x_sup = y_inf, y_sup

    def evaluate_prefix(n):
        """Return (gmm_cvar, normal_cvar, lower, upper) for the first n samples."""
        gmm_value = cvar_estimator(gmm_samples[:n], alpha=alpha)
        normal_value = cvar_estimator(normal_samples[:n], alpha=alpha)
        lower, upper = stochastic_cvar_bound_const_eps(
            y_samp=normal_samples[:n],
            y_sup=y_sup,
            y_inf=y_inf,
            x_sup=x_sup,
            x_inf=x_inf,
            eps=eps,
            delta=delta,
            alpha=alpha
        )
        return gmm_value, normal_value, lower, upper

    # One task per sample size.
    tasks = (delayed(evaluate_prefix)(n=size) for size in sample_sizes)
    results = Parallel(n_jobs=n_jobs)(tasks)

    gmm_cvars = [row[0] for row in results]
    normal_cvars = [row[1] for row in results]
    normal_lower_bounds = [row[2] for row in results]
    normal_upper_bounds = [row[3] for row in results]

    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}
    estimate_style = {'linewidth': 4, 'markersize': 10}
    edge_style = {'alpha': 0.7, 'linewidth': 1}

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts

    # CVaR estimates
    plt.plot(sample_sizes, gmm_cvars, 'b-o', label='Truncated GMM CVaR', **estimate_style)
    plt.plot(sample_sizes, normal_cvars, 'r-s', label='Truncated Normal CVaR', **estimate_style)

    # Shaded band between the bounds, outlined with dashed edges
    plt.fill_between(sample_sizes, normal_lower_bounds, normal_upper_bounds,
                     alpha=0.3, color='red', label=f'Normal CVaR Bounds (δ={delta})')
    plt.plot(sample_sizes, normal_lower_bounds, 'r--', **edge_style)
    plt.plot(sample_sizes, normal_upper_bounds, 'r--', **edge_style)

    plt.xlabel('Number of Samples', **label_font)
    plt.ylabel(f'CVaR (α={alpha})', **label_font)
    plt.title(f'CVaR Bounds Convergence',
              fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)
    plt.ylim(bottom=0)  # y-axis starts at 0

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("cvar_convergence_with_bounds")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved CVaR convergence with bounds plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

    return gmm_cvars, normal_cvars, normal_lower_bounds, normal_upper_bounds

def plot_cvar_convergence(gmm_samples, normal_samples, alpha, sample_sizes):
    """
    Plot CVaR estimates against sample size for both sample sets and save
    the figure.

    Args:
        gmm_samples: Samples plotted as 'Truncated GMM'.
        normal_samples: Samples plotted as 'Truncated Normal Approximation'.
        alpha: Passed as `alpha` to cvar_estimator.
        sample_sizes: Prefix lengths to evaluate. None selects sizes from 100
            up to the smaller sample count in steps of that count // 20
            (at least 1).

    Returns:
        Tuple of lists (gmm_cvars, normal_cvars), one entry per sample size.
    """
    if sample_sizes is None:
        # Roughly 20 evenly spaced sizes over the available samples.
        n_available = min(len(gmm_samples), len(normal_samples))
        stride = max(1, n_available // 20)
        sample_sizes = np.arange(100, n_available + 1, stride)

    # Evaluate both estimators on the first n samples for every size n.
    pairs = [
        (cvar_estimator(gmm_samples[:n], alpha=alpha),
         cvar_estimator(normal_samples[:n], alpha=alpha))
        for n in sample_sizes
    ]
    gmm_cvars = [pair[0] for pair in pairs]
    normal_cvars = [pair[1] for pair in pairs]

    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}
    line_style = {'linewidth': 4, 'markersize': 10}

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts
    plt.plot(sample_sizes, gmm_cvars, 'b-o', label='Truncated GMM', **line_style)
    plt.plot(sample_sizes, normal_cvars, 'r-s', label='Truncated Normal Approximation', **line_style)

    plt.xlabel('Number of Samples', **label_font)
    plt.ylabel(f'CVaR (α={alpha})', **label_font)
    plt.title('CVaR Estimator Convergence vs Sample Size', fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("cvar_convergence_simple")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved CVaR convergence plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

    return gmm_cvars, normal_cvars

def plot_cvar_convergence_parallel(gmm_samples, normal_samples, alpha, sample_sizes, n_jobs):
    """
    Parallel variant of the CVaR convergence plot: the per-size CVaR estimates
    are computed as joblib tasks, then plotted and saved.

    Args:
        gmm_samples: Samples plotted as 'Truncated GMM'.
        normal_samples: Samples plotted as 'Truncated Normal Approximation'.
        alpha: Passed as `alpha` to cvar_estimator.
        sample_sizes: Prefix lengths to evaluate. None selects sizes from 100
            up to the smaller sample count in steps of that count // 20
            (at least 1).
        n_jobs: Forwarded to joblib.Parallel.

    Returns:
        Tuple of lists (gmm_cvars, normal_cvars), one entry per sample size.
    """
    if sample_sizes is None:
        # Roughly 20 evenly spaced sizes over the available samples.
        n_available = min(len(gmm_samples), len(normal_samples))
        stride = max(1, n_available // 20)
        sample_sizes = np.arange(100, n_available + 1, stride)

    def evaluate_prefix(n):
        """Return (gmm_cvar, normal_cvar) for the first n samples."""
        gmm_value = cvar_estimator(gmm_samples[:n], alpha=alpha)
        normal_value = cvar_estimator(normal_samples[:n], alpha=alpha)
        return gmm_value, normal_value

    # One task per sample size.
    tasks = (delayed(evaluate_prefix)(n=size) for size in sample_sizes)
    results = Parallel(n_jobs=n_jobs)(tasks)

    gmm_cvars = [pair[0] for pair in results]
    normal_cvars = [pair[1] for pair in results]

    label_font = {'fontsize': 42, 'fontweight': 'bold'}
    tick_font = {'fontsize': 31, 'fontweight': 'bold'}
    line_style = {'linewidth': 4, 'markersize': 10}

    plt.figure(figsize=(16, 12))  # Large canvas to fit the oversized fonts
    plt.plot(sample_sizes, gmm_cvars, 'b-o', label='Truncated GMM', **line_style)
    plt.plot(sample_sizes, normal_cvars, 'r-s', label='Truncated Normal Approximation', **line_style)

    plt.xlabel('Number of Samples', **label_font)
    plt.ylabel(f'CVaR (α={alpha})', **label_font)
    plt.title('CVaR Estimator Convergence vs Sample Size', fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)

    plt.xticks(**tick_font)
    plt.yticks(**tick_font)

    plt.tight_layout()

    output_path = get_plot_filename("cvar_convergence_alt")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved CVaR convergence plot to: {output_path}")

    # Nothing is shown interactively; release the figure's memory.
    plt.close()

    return gmm_cvars, normal_cvars

def ensure_plots_directory():
    """Return the plots directory path, creating the directory when it is missing."""
    target = Path("./dist_disc_plots")
    # exist_ok makes repeated calls harmless; missing parents are created too.
    os.makedirs(target, exist_ok=True)
    return target

def get_plot_filename(base_name, extension="png"):
    """
    Build the save path "<plots dir>/<base_name>.<extension>".

    The name is deterministic, so saving a plot under the same base name
    overwrites the previous file. The plots directory is created if needed.
    """
    directory = ensure_plots_directory()
    file_name = f"{base_name}.{extension}"
    return directory.joinpath(file_name)

def list_saved_plots():
    """
    Print the names of the PNG files in the plots directory.

    Returns:
        List of the matching paths, in the order glob produced them (only the
        printed listing is sorted).
    """
    directory = ensure_plots_directory()
    found = list(directory.glob("*.png"))

    if not found:
        print(f"\nNo plots found in {directory}")
        return found

    print(f"\nSaved plots in {directory}:")
    for entry in sorted(found):
        print(f"  - {entry.name}")
    return found

def plot_sampling_time_comparison(gmm_sampling_func, normal_sampling_func, gmm_args, normal_args,
                                  sample_sizes, n_runs, n_jobs):
    """
    Plot sampling time comparison between GMM and Normal samples across different sample sizes.

    For every sample size, each sampler is timed n_runs times (dispatched via
    joblib). The plot shows the mean time with the standard error of the mean
    (std / sqrt(n_runs)) as error bars, plus the average GMM/Normal time ratio.

    Args:
        gmm_sampling_func: Function to sample from GMM
        normal_sampling_func: Function to sample from Normal distribution
        gmm_args: Arguments for GMM sampling function
        normal_args: Arguments for Normal sampling function
        sample_sizes: List of sample sizes to test; None selects 20
            log-spaced sizes from 100 to 10000
        n_runs: Number of runs to average timing over
        n_jobs: Number of jobs for parallel execution

    Returns:
        Tuple (gmm_times, gmm_stds, normal_times, normal_stds, sample_sizes),
        where the *_stds lists hold standard errors of the mean.
    """
    if sample_sizes is None:
        # Use a reasonable range for timing comparison
        sample_sizes = np.logspace(2, 4, 20).astype(int)  # 100 to 10000, 20 points

    def measure_sampling_time_for_size(sampling_func, args, n_samples, n_runs):
        """Return (mean time, standard error) of sampling_func at n_samples."""
        def single_run():
            # perf_counter is monotonic and high-resolution; time.time() can be
            # too coarse for sub-millisecond sampling calls, which would also
            # make the ratio below divide by zero.
            start_time = time.perf_counter()
            # Known samplers are called with explicit keywords so that
            # n_samples overrides whatever size the args dict may carry.
            if sampling_func is sample_truncated_gmm_1d:
                sample_truncated_gmm_1d(
                    means=args['means'],
                    variances=args['variances'],
                    weights=args['weights'],
                    n_samples=n_samples,
                    lower=args['lower'],
                    upper=args['upper']
                )
            elif sampling_func is sample_truncated_normal_approximation:
                sample_truncated_normal_approximation(
                    mean=args['mean'],
                    std=args['std'],
                    n_samples=n_samples,
                    lower=args['lower'],
                    upper=args['upper']
                )
            else:
                # Generic fallback
                if isinstance(args, dict):
                    args_copy = args.copy()
                    args_copy['n_samples'] = n_samples
                    sampling_func(**args_copy)
                else:
                    # NOTE(review): positional args are forwarded unchanged, so
                    # n_samples is NOT applied on this path - the measured time
                    # will not vary with the sample size.
                    sampling_func(*args)
            return time.perf_counter() - start_time

        # Run multiple times and average
        times = Parallel(n_jobs=n_jobs)(delayed(single_run)() for _ in range(n_runs))
        return np.mean(times), np.std(times) / np.sqrt(n_runs)

    gmm_times = []
    gmm_stds = []
    normal_times = []
    normal_stds = []

    print("Measuring sampling times across different sample sizes...")

    for i, n in enumerate(sample_sizes):
        if i % 5 == 0:  # Progress indicator
            print(f"  Processing sample size {n}/{sample_sizes[-1]}")

        # Measure GMM sampling time
        gmm_mean, gmm_std = measure_sampling_time_for_size(gmm_sampling_func, gmm_args, n, n_runs)
        gmm_times.append(gmm_mean)
        gmm_stds.append(gmm_std)

        # Measure Normal sampling time
        normal_mean, normal_std = measure_sampling_time_for_size(normal_sampling_func, normal_args, n, n_runs)
        normal_times.append(normal_mean)
        normal_stds.append(normal_std)

    # Create the plot
    plt.figure(figsize=(16, 12))  # Increased figure size for larger text

    # Plot mean times; the error bars show the standard error of the mean
    plt.errorbar(sample_sizes, gmm_times, yerr=gmm_stds,
                 fmt='b-o', label='GMM Sampling', linewidth=4, markersize=10, capsize=8)
    plt.errorbar(sample_sizes, normal_times, yerr=normal_stds,
                 fmt='r-s', label='Normal Sampling', linewidth=4, markersize=10, capsize=8)

    plt.xlabel('Number of Samples', fontsize=42, fontweight='bold')
    plt.ylabel('Sampling Time (seconds)', fontsize=42, fontweight='bold')
    plt.title(f'Sampling Time Comparison ({n_runs} runs)', fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)
    plt.xscale('log')
    plt.yscale('linear')  # Linear y-axis keeps absolute time differences visible

    # Increase tick label sizes
    plt.xticks(fontsize=31, fontweight='bold')
    plt.yticks(fontsize=31, fontweight='bold')

    # Annotate with the mean of the per-size GMM/Normal time ratios
    if len(sample_sizes) > 0:
        speedup_ratios = [g / n for g, n in zip(gmm_times, normal_times)]
        avg_speedup = np.mean(speedup_ratios)
        plt.text(0.05, 0.75, f'Average GMM/Normal ratio: {avg_speedup:.2f}x',
                 transform=plt.gca().transAxes, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
                 fontsize=36, fontweight='bold')

    plt.tight_layout()

    # Save the plot
    plot_filename = get_plot_filename("sampling_time_comparison")
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Saved sampling time comparison plot to: {plot_filename}")

    # Don't show plot in non-interactive mode
    plt.close()  # Close the figure to free memory

    return gmm_times, gmm_stds, normal_times, normal_stds, sample_sizes

def plot_cvar_computation_time_comparison(gmm_sampling_func, normal_sampling_func, gmm_args, normal_args, 
                                         sample_sizes, alpha, delta, eps, n_runs, n_jobs):
    """
    Plot comparison of CVaR computation times: Normal bounds (sampling + bounds) vs GMM estimator (sampling + CVaR).

    For every sample size, each pipeline is executed ``n_runs`` times through
    joblib and the elapsed time of every run is measured inside the worker with
    ``time.perf_counter``. The resulting curves are drawn on log-log axes, the
    figure is saved to the path returned by
    ``get_plot_filename("cvar_computation_time_comparison")`` and then closed.

    Args:
        gmm_sampling_func: Function to sample from GMM
        normal_sampling_func: Function to sample from Normal distribution
        gmm_args: Arguments for GMM sampling function
        normal_args: Arguments for Normal sampling function
        sample_sizes: List of sample sizes to test; ``None`` selects 20
            log-spaced sizes between 100 and 10000
        alpha: CVaR alpha parameter
        delta: Confidence parameter for bounds
        eps: Distributional discrepancy parameter (KS distance)
        n_runs: Number of runs to average timing over
        n_jobs: Number of jobs for parallel execution

    Returns:
        Tuple ``(gmm_times, gmm_stds, normal_times, normal_stds, sample_sizes)``.
        The ``*_times`` lists hold the mean elapsed seconds per sample size and
        the ``*_stds`` lists hold the standard error of that mean
        (``std / sqrt(n_runs)``), not the raw standard deviation.
    """
    if sample_sizes is None:
        # Use a reasonable range for timing comparison
        sample_sizes = np.logspace(2, 4, 20).astype(int)  # 100 to 10000, 20 points

    def measure_gmm_cvar_time(args, n_samples, n_runs):
        """Measure time for GMM sampling + CVaR computation."""
        def single_run():
            # perf_counter is the clock intended for short intervals;
            # time.time() is an adjustable wall clock and may be too coarse.
            start_time = time.perf_counter()
            # Sample from GMM
            if gmm_sampling_func == sample_truncated_gmm_1d:
                # The truncated GMM sampler returns a pair; only the samples are needed.
                samples, _ = sample_truncated_gmm_1d(
                    means=args['means'],
                    variances=args['variances'],
                    weights=args['weights'],
                    n_samples=n_samples,
                    lower=args['lower'],
                    upper=args['upper']
                )
            elif isinstance(args, dict):
                # Generic fallback: inject the sample count without mutating the caller's dict.
                # The return value is used as the sample array as-is (no tuple unpacking).
                args_copy = args.copy()
                args_copy['n_samples'] = n_samples
                samples = gmm_sampling_func(**args_copy)
            else:
                # NOTE(review): n_samples is not forwarded for positional args, so the
                # requested size only takes effect if it is already part of args.
                samples = gmm_sampling_func(*args)

            # Compute CVaR (result is discarded; only the elapsed time matters)
            cvar_estimator(samples, alpha=alpha)
            return time.perf_counter() - start_time

        # Run multiple times and average
        times = Parallel(n_jobs=n_jobs)(delayed(single_run)() for _ in range(n_runs))
        return np.mean(times), np.std(times) / np.sqrt(n_runs)

    def measure_normal_bounds_time(args, n_samples, n_runs):
        """Measure time for Normal sampling + bounds computation."""
        def single_run():
            start_time = time.perf_counter()
            # Sample from Normal
            if normal_sampling_func == sample_truncated_normal_approximation:
                samples = sample_truncated_normal_approximation(
                    mean=args['mean'],
                    std=args['std'],
                    n_samples=n_samples,
                    lower=args['lower'],
                    upper=args['upper']
                )
            elif isinstance(args, dict):
                # Generic fallback: inject the sample count without mutating the caller's dict.
                args_copy = args.copy()
                args_copy['n_samples'] = n_samples
                samples = normal_sampling_func(**args_copy)
            else:
                # NOTE(review): n_samples is not forwarded for positional args, so the
                # requested size only takes effect if it is already part of args.
                samples = normal_sampling_func(*args)

            # Compute bounds; the empirical sample range serves as the support
            # for both the x and the y limits.
            y_inf = np.min(samples)
            y_sup = np.max(samples)
            stochastic_cvar_bound_const_eps(
                y_samp=samples,
                y_sup=y_sup,
                y_inf=y_inf,
                x_sup=y_sup,
                x_inf=y_inf,
                eps=eps,
                delta=delta,
                alpha=alpha
            )
            return time.perf_counter() - start_time

        # Run multiple times and average
        times = Parallel(n_jobs=n_jobs)(delayed(single_run)() for _ in range(n_runs))
        return np.mean(times), np.std(times) / np.sqrt(n_runs)

    gmm_times = []
    gmm_stds = []
    normal_times = []
    normal_stds = []

    print("Measuring CVaR computation times across different sample sizes...")

    for i, n in enumerate(sample_sizes):
        if i % 5 == 0:  # Progress indicator
            print(f"  Processing sample size {n}/{sample_sizes[-1]}")

        # Measure GMM CVaR computation time (sampling + CVaR)
        gmm_mean, gmm_std = measure_gmm_cvar_time(gmm_args, n, n_runs)
        gmm_times.append(gmm_mean)
        gmm_stds.append(gmm_std)

        # Measure Normal bounds computation time (sampling + bounds)
        normal_mean, normal_std = measure_normal_bounds_time(normal_args, n, n_runs)
        normal_times.append(normal_mean)
        normal_stds.append(normal_std)

    # Create the plot
    plt.figure(figsize=(16, 12))  # Increased figure size for larger text

    # Plot mean times; the error bars show the standard error of the mean
    plt.errorbar(sample_sizes, gmm_times, yerr=gmm_stds,
                 fmt='b-o', label='GMM: Sampling + CVaR', linewidth=4, markersize=10, capsize=8)
    plt.errorbar(sample_sizes, normal_times, yerr=normal_stds,
                 fmt='r-s', label='Normal: Sampling + Bounds', linewidth=4, markersize=10, capsize=8)

    plt.xlabel('Number of Samples', fontsize=42, fontweight='bold')
    plt.ylabel('Computation Time (seconds)', fontsize=42, fontweight='bold')
    plt.title('CVaR Computation Time Comparison', fontsize=47, fontweight='bold', pad=20)
    plt.legend(prop={'size': 36, 'weight': 'bold'})
    plt.grid(True, alpha=0.3)
    plt.xscale('log')
    plt.yscale('log')

    # Increase tick label sizes
    plt.xticks(fontsize=31, fontweight='bold')
    plt.yticks(fontsize=31, fontweight='bold')

    # Annotate with the mean of the per-size GMM/Normal time ratios
    if len(sample_sizes) > 0:
        speedup_ratios = [g/n for g, n in zip(gmm_times, normal_times)]
        avg_speedup = np.mean(speedup_ratios)
        plt.text(0.05, 0.75, f'Average GMM/Normal ratio: {avg_speedup:.2f}x',
                 transform=plt.gca().transAxes, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
                 fontsize=36, fontweight='bold')

    plt.tight_layout()

    # Save the plot
    plot_filename = get_plot_filename("cvar_computation_time_comparison")
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Saved CVaR computation time comparison plot to: {plot_filename}")

    # Don't show plot in non-interactive mode
    plt.close()  # Close the figure to free memory

    return gmm_times, gmm_stds, normal_times, normal_stds, sample_sizes

# Example usage:
if __name__ == "__main__":
    # End-to-end driver: time the samplers, compare the distributions via KS
    # distance and CVaR, then generate every plot. All settings come from the
    # module-level configuration constants.
    means = GMM_MEANS
    variances = GMM_VARIANCES
    weights = GMM_WEIGHTS
    lower = TRUNCATION_LOWER
    upper = TRUNCATION_UPPER

    n_samples = N_SAMPLES
    n_runs = N_RUNS
    alpha = ALPHA
    delta = DELTA
    n_jobs = N_JOBS

    # Show plots directory
    plots_dir = ensure_plots_directory()
    print(f"Plots will be saved to: {plots_dir}")
    print(f"Using configuration: ALPHA={ALPHA}, DELTA={DELTA}, N_SAMPLES={N_SAMPLES}, N_RUNS={N_RUNS}")
    print(f"GMM parameters: means={GMM_MEANS}, variances={GMM_VARIANCES}, weights={GMM_WEIGHTS}")
    print(f"Truncation bounds: [{TRUNCATION_LOWER}, {TRUNCATION_UPPER}]")

    print("Measuring sampling times...")

    # Measure truncated GMM sampling time
    print("Measuring GMM sampling times...")
    gmm_trunc_mean_time, gmm_trunc_std_time, gmm_samples = measure_sampling_time_parallel(
        sampling_func=sample_truncated_gmm_1d,
        n_runs=n_runs,
        n_jobs=n_jobs,
        means=means,
        variances=variances,
        weights=weights,
        n_samples=n_samples,
        lower=lower,
        upper=upper
    )

    # Measure non-truncated GMM sampling time
    gmm_non_trunc_mean_time, gmm_non_trunc_std_time, gmm_non_trunc_samples = measure_sampling_time_parallel(
        sampling_func=sample_non_truncated_gmm_1d,
        n_runs=n_runs,
        n_jobs=n_jobs,
        means=means,
        variances=variances,
        weights=weights,
        n_samples=n_samples
    )

    # Create truncated Normal approximation and measure sampling time
    approx_mean, approx_std = create_truncated_normal_approximation(
        means=means,
        variances=variances,
        weights=weights,
        lower=lower,
        upper=upper
    )

    # Measure truncated Normal sampling time
    print("Measuring Normal sampling times...")
    normal_trunc_mean_time, normal_trunc_std_time, normal_samples = measure_sampling_time_parallel(
        sampling_func=sample_truncated_normal_approximation,
        n_runs=n_runs,
        n_jobs=n_jobs,
        mean=approx_mean,
        std=approx_std,
        n_samples=n_samples,
        lower=lower,
        upper=upper
    )

    # Measure non-truncated Normal sampling time
    normal_non_trunc_mean_time, normal_non_trunc_std_time, normal_non_trunc_samples = measure_sampling_time_parallel(
        sampling_func=sample_non_truncated_normal_approximation,
        n_runs=n_runs,
        n_jobs=n_jobs,
        mean=approx_mean,
        std=approx_std,
        n_samples=n_samples
    )

    # Compute KS distances between truncated and non-truncated samples
    gmm_ks_statistic, gmm_ks_p_value = compute_ks_distance(
        samples1=gmm_samples,
        samples2=gmm_non_trunc_samples
    )
    normal_ks_statistic, normal_ks_p_value = compute_ks_distance(
        samples1=normal_samples,
        samples2=normal_non_trunc_samples
    )

    # Demonstrate KS distance computation
    print("Computing KS distances...")
    sample_pairs = [
        (gmm_samples, gmm_non_trunc_samples, "GMM_trunc", "GMM_non_trunc"),
        (normal_samples, normal_non_trunc_samples, "Normal_trunc", "Normal_non_trunc"),
        (gmm_samples, normal_samples, "GMM_trunc", "Normal_trunc"),
        (gmm_non_trunc_samples, normal_non_trunc_samples, "GMM_non_trunc", "Normal_non_trunc")
    ]

    ks_results = compute_multiple_ks_distances_parallel(
        sample_pairs=sample_pairs,
        n_jobs=n_jobs
    )

    # Truncated-vs-non-truncated KS distances (reported below)
    eps_gmm = gmm_ks_statistic
    eps_normal = normal_ks_statistic

    print("\nKS Distances (Truncated vs Non-truncated):")
    print(f"GMM KS distance: {eps_gmm:.4f}")
    print(f"Normal KS distance: {eps_normal:.4f}")

    print("\nKS Distance Results:")
    for statistic, p_value, name1, name2 in ks_results:
        print(f"{name1} vs {name2}: KS={statistic:.4f}, p-value={p_value:.4f}")

    # Compute CVaR for truncated samples only
    gmm_trunc_cvar = cvar_estimator(gmm_samples, alpha=alpha)
    normal_trunc_cvar = cvar_estimator(normal_samples, alpha=alpha)

    # Compute KS distance between GMM and Normal
    ks_statistic, ks_p_value = compute_ks_distance(
        samples1=gmm_samples,
        samples2=normal_samples
    )

    # Use the KS distance between GMM and Normal as the epsilon for bounds
    eps_distributional = ks_statistic
    print(f"Distributional KS distance (GMM vs Normal): {eps_distributional:.4f}")

    print(f"\nSampling Time Results (averaged over {n_runs} runs):")
    print(f"Truncated GMM: {gmm_trunc_mean_time:.6f} ± {gmm_trunc_std_time:.6f} seconds")
    print(f"Non-truncated GMM: {gmm_non_trunc_mean_time:.6f} ± {gmm_non_trunc_std_time:.6f} seconds")
    print(f"Truncated Normal: {normal_trunc_mean_time:.6f} ± {normal_trunc_std_time:.6f} seconds")
    print(f"Non-truncated Normal: {normal_non_trunc_mean_time:.6f} ± {normal_non_trunc_std_time:.6f} seconds")

    print("\nSpeedup (Non-truncated vs Truncated):")
    print(f"GMM: {gmm_trunc_mean_time/gmm_non_trunc_mean_time:.2f}x slower")
    print(f"Normal: {normal_trunc_mean_time/normal_non_trunc_mean_time:.2f}x slower")

    print(f"\nCVaR Results (alpha={alpha}):")
    print(f"Truncated GMM CVaR: {gmm_trunc_cvar:.4f}")
    print(f"Truncated Normal CVaR: {normal_trunc_cvar:.4f}")
    print(f"CVaR Difference (GMM - Normal): {gmm_trunc_cvar - normal_trunc_cvar:.4f}")

    print("\nOther Results:")
    print(f"GMM samples shape: {gmm_samples.shape}")
    print(f"Normal approximation samples shape: {normal_samples.shape}")
    print(f"KS statistic (GMM vs Normal): {ks_statistic:.4f}")
    print(f"KS p-value (GMM vs Normal): {ks_p_value:.4f}")
    print(f"Approximation mean: {approx_mean:.4f}")
    print(f"Approximation std: {approx_std:.4f}")

    # Plot histograms
    plot_sample_histograms(
        gmm_samples=gmm_samples,
        normal_samples=normal_samples,
        bins=50,
        alpha=0.7
    )

    # Plot CDFs
    plot_sample_cdfs(gmm_samples, normal_samples)

    # Compare CVaR convergence computation methods
    print("\nComparing CVaR convergence computation methods...")

    # Generate the plot using the parallel version and time the call, so the
    # figure reported below is a real measurement instead of a placeholder.
    # The measured interval covers the whole call, plot generation included.
    parallel_start = time.perf_counter()
    gmm_cvars, normal_cvars, normal_lower_bounds, normal_upper_bounds = plot_cvar_convergence_with_bounds_parallel(
        gmm_samples=gmm_samples,
        normal_samples=normal_samples,
        alpha=alpha,
        delta=delta,
        eps=eps_distributional,
        sample_sizes=None,
        n_jobs=n_jobs
    )
    parallel_time = time.perf_counter() - parallel_start

    # For performance comparison, compute the same results sequentially without plotting
    print("Computing CVaR convergence sequentially for performance comparison...")
    start_time = time.perf_counter()

    # Dynamic range based on available samples
    # NOTE(review): the element-wise comparison further down assumes this grid
    # matches the default grid of the parallel version — confirm.
    max_samples = min(len(gmm_samples), len(normal_samples))
    sample_sizes = np.arange(100, max_samples + 1, max(1, max_samples // 20))  # ~20 points

    # Get bounds for the Normal distribution
    y_inf = np.min(normal_samples)
    y_sup = np.max(normal_samples)
    x_inf = y_inf  # Assuming same bounds for x and y
    x_sup = y_sup

    gmm_cvars_seq = []
    normal_cvars_seq = []
    normal_lower_bounds_seq = []
    normal_upper_bounds_seq = []

    for n in sample_sizes:
        # Use first n samples from each distribution
        gmm_cvar = cvar_estimator(gmm_samples[:n], alpha=alpha)
        normal_cvar = cvar_estimator(normal_samples[:n], alpha=alpha)

        # Calculate bounds for Normal distribution using stochastic bounds
        lower_bound, upper_bound = stochastic_cvar_bound_const_eps(
            y_samp=normal_samples[:n],
            y_sup=y_sup,
            y_inf=y_inf,
            x_sup=x_sup,
            x_inf=x_inf,
            eps=eps_distributional,
            delta=delta,
            alpha=alpha
        )

        gmm_cvars_seq.append(gmm_cvar)
        normal_cvars_seq.append(normal_cvar)
        normal_lower_bounds_seq.append(lower_bound)
        normal_upper_bounds_seq.append(upper_bound)

    sequential_time = time.perf_counter() - start_time

    print(f"Sequential execution time: {sequential_time:.4f} seconds")
    print(f"Parallel execution time: {parallel_time:.4f} seconds (includes plot generation)")
    print("Note: Parallel version was used to generate the plot")

    # Verify results are the same
    gmm_diff = np.max(np.abs(np.array(gmm_cvars_seq) - np.array(gmm_cvars)))
    normal_diff = np.max(np.abs(np.array(normal_cvars_seq) - np.array(normal_cvars)))
    print(f"Maximum difference in GMM CVaR results: {gmm_diff:.10f}")
    print(f"Maximum difference in Normal CVaR results: {normal_diff:.10f}")

    # Plot sampling time comparison across different sample sizes
    print("\nGenerating sampling time comparison plot...")
    gmm_sampling_args = {
        'means': means,
        'variances': variances,
        'weights': weights,
        'lower': lower,
        'upper': upper
    }
    normal_sampling_args = {
        'mean': approx_mean,
        'std': approx_std,
        'lower': lower,
        'upper': upper
    }

    gmm_times, gmm_stds, normal_times, normal_stds, timing_sample_sizes = plot_sampling_time_comparison(
        gmm_sampling_func=sample_truncated_gmm_1d,
        normal_sampling_func=sample_truncated_normal_approximation,
        gmm_args=gmm_sampling_args,
        normal_args=normal_sampling_args,
        sample_sizes=None,  # will use default range
        n_runs=n_runs,
        n_jobs=n_jobs
    )

    # Print summary statistics
    print("\nSampling Time Comparison Summary:")
    print(f"Sample sizes tested: {len(timing_sample_sizes)} points from {timing_sample_sizes[0]} to {timing_sample_sizes[-1]}")
    print(f"Average GMM sampling time: {np.mean(gmm_times):.6f} ± {np.mean(gmm_stds):.6f} seconds")
    print(f"Average Normal sampling time: {np.mean(normal_times):.6f} ± {np.mean(normal_stds):.6f} seconds")
    print(f"Average speedup (GMM/Normal): {np.mean([g/n for g, n in zip(gmm_times, normal_times)]):.2f}x")

    # Plot CVaR computation time comparison (sampling + computation)
    print("\nGenerating CVaR computation time comparison plot...")
    gmm_cvar_times, gmm_cvar_stds, normal_cvar_times, normal_cvar_stds, cvar_timing_sample_sizes = plot_cvar_computation_time_comparison(
        gmm_sampling_func=sample_truncated_gmm_1d,
        normal_sampling_func=sample_truncated_normal_approximation,
        gmm_args=gmm_sampling_args,
        normal_args=normal_sampling_args,
        sample_sizes=None,  # will use default range
        alpha=alpha,
        delta=delta,
        eps=eps_distributional,
        n_runs=n_runs,
        n_jobs=n_jobs
    )

    # Print CVaR computation time summary statistics
    print("\nCVaR Computation Time Comparison Summary:")
    print(f"Sample sizes tested: {len(cvar_timing_sample_sizes)} points from {cvar_timing_sample_sizes[0]} to {cvar_timing_sample_sizes[-1]}")
    print(f"Average GMM CVaR computation time: {np.mean(gmm_cvar_times):.6f} ± {np.mean(gmm_cvar_stds):.6f} seconds")
    print(f"Average Normal bounds computation time: {np.mean(normal_cvar_times):.6f} ± {np.mean(normal_cvar_stds):.6f} seconds")
    print(f"Average speedup (GMM CVaR / Normal bounds): {np.mean([g/n for g, n in zip(gmm_cvar_times, normal_cvar_times)]):.2f}x")

    # List all saved plots
    list_saved_plots()


