import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))
from bounds import cvar_estimator, cvar_probabilistic_lower_bound_thomas, cvar_probabilistic_upper_bound_thomas, cvar_interpretable_concentration_inequality

def generate_truncated_normal_samples(n_samples: int, lower_bound: float = -1, upper_bound: float = 1) -> np.ndarray:
    """
    Draw samples from a standard normal N(0, 1) clipped to [lower_bound, upper_bound].

    Args:
        n_samples: How many values to draw
        lower_bound: Left truncation point
        upper_bound: Right truncation point

    Returns:
        np.ndarray: The ``n_samples`` drawn values, all inside the truncation interval
    """
    # scipy.stats.truncnorm expects its clip points in standardised units,
    # i.e. (x - loc) / scale. Because loc=0 and scale=1 here, the raw bounds
    # can be handed over unchanged.
    return stats.truncnorm.rvs(lower_bound, upper_bound, loc=0, scale=1, size=n_samples)

def generate_beta_samples(n_samples: int, alpha: float, beta: float) -> np.ndarray:
    """
    Draw samples from Beta(alpha, beta), whose support is [0, 1].

    Args:
        n_samples: How many values to draw
        alpha: First shape parameter (``a`` in scipy)
        beta: Second shape parameter (``b`` in scipy)

    Returns:
        np.ndarray: The ``n_samples`` drawn values
    """
    draws = stats.beta.rvs(alpha, beta, size=n_samples)
    return draws

def generate_laplace_samples(n_samples: int, loc: float = 0, scale: float = 1) -> np.ndarray:
    """
    Draw samples from a Laplace (double-exponential) distribution.

    Args:
        n_samples: How many values to draw
        loc: Centre of the distribution
        scale: Scale (spread) of the distribution

    Returns:
        np.ndarray: The ``n_samples`` drawn values
    """
    draws = stats.laplace.rvs(loc=loc, scale=scale, size=n_samples)
    return draws

def run_simulation_for_distribution(dist_name: str, sample_generator, dist_params: dict, 
                                  sample_sizes: np.ndarray, alpha: float, delta: float, 
                                  n_iterations: int, dist_bounds: tuple,
                                  n_reference_samples: int = 10000) -> dict:
    """
    Run the CVaR bound simulation for a single distribution.

    A reference ("true") CVaR is approximated by the empirical CVaR of one
    large sample. Then, for every sample size ``n`` in ``sample_sizes``,
    ``n_iterations`` fresh samples of size ``n`` are drawn. On each one the
    empirical CVaR, the Thomas lower/upper bounds and the interpretable
    ("stochastic") bounds are computed. The values are averaged over the
    iterations. The function also records the fraction of iterations whose
    bounds bracket the reference CVaR (the empirical coverage that the
    ``delta`` guarantee is about).

    Args:
        dist_name: Name of the distribution (used in log output and the result)
        sample_generator: Called as ``sample_generator(n, **dist_params)`` to draw ``n`` samples
        dist_params: Keyword arguments forwarded to ``sample_generator``
        sample_sizes: Array of sample sizes to test
        alpha: CVaR confidence level
        delta: Probability for Thomas bounds
        n_iterations: Number of independent repetitions per sample size
        dist_bounds: (lower_bound, upper_bound) handed to the bound functions
            as the support of the distribution
        n_reference_samples: Size of the large sample used to approximate the
            true CVaR (default 10000, the previously hard-coded value)

    Returns:
        dict with keys
            'name', 'true_cvar',
            'cvar_estimates', 'lower_bounds', 'upper_bounds',
            'stochastic_lower_bounds', 'stochastic_upper_bounds'
                (one value per sample size, averaged over the iterations),
            'thomas_coverage', 'stochastic_coverage'
                (one value per sample size: fraction of iterations with
                lower <= true_cvar <= upper; NaN when n_iterations == 0).
    """
    print(f"\nRunning simulation for {dist_name}...")
    lower_support, upper_support = dist_bounds[0], dist_bounds[1]

    # Approximate the true CVaR by the empirical CVaR of one large sample.
    large_sample = sample_generator(n_reference_samples, **dist_params)
    true_cvar = cvar_estimator(large_sample, alpha)
    print(f"True CVaR ({dist_name}, alpha={alpha}): {true_cvar:.4f}")

    cvar_estimates = []
    thomas_lower_bounds = []
    thomas_upper_bounds = []
    stochastic_lower_bounds = []
    stochastic_upper_bounds = []
    thomas_coverage = []
    stochastic_coverage = []
    # Draws outside dist_bounds contradict the support passed to the bound
    # functions (e.g. the "approximate" Laplace bounds), so count them.
    n_out_of_support = 0

    for n in sample_sizes:
        # Per-iteration results for this sample size
        cvar_results = []
        lower_bound_results = []
        upper_bound_results = []
        stochastic_lower_results = []
        stochastic_upper_results = []
        thomas_hits = 0
        stochastic_hits = 0

        for _ in range(n_iterations):
            samples = sample_generator(n, **dist_params)
            sample_array = np.asarray(samples)
            n_out_of_support += int(np.count_nonzero(
                (sample_array < lower_support) | (sample_array > upper_support)
            ))

            # Same call order as before, so any RNG use inside the helpers
            # stays reproducible.
            cvar_results.append(cvar_estimator(samples, alpha))
            thomas_lower = cvar_probabilistic_lower_bound_thomas(
                samples, alpha, delta, lower_support
            )
            thomas_upper = cvar_probabilistic_upper_bound_thomas(
                samples, alpha, delta, upper_support
            )
            # Argument order (upper, lower, delta, alpha) is kept exactly as in
            # the original call; the helper returns a (lower, upper) pair.
            stochastic_lower, stochastic_upper = cvar_interpretable_concentration_inequality(
                samples, upper_support, lower_support, delta, alpha
            )

            lower_bound_results.append(thomas_lower)
            upper_bound_results.append(thomas_upper)
            stochastic_lower_results.append(stochastic_lower)
            stochastic_upper_results.append(stochastic_upper)

            # Per-draw coverage: does this iteration's interval contain the reference?
            thomas_hits += bool(thomas_lower <= true_cvar <= thomas_upper)
            stochastic_hits += bool(stochastic_lower <= true_cvar <= stochastic_upper)

        # Average the results over iterations
        avg_cvar = np.mean(cvar_results)
        avg_lower = np.mean(lower_bound_results)
        avg_upper = np.mean(upper_bound_results)
        avg_stochastic_lower = np.mean(stochastic_lower_results)
        avg_stochastic_upper = np.mean(stochastic_upper_results)
        thomas_cov = thomas_hits / n_iterations if n_iterations else float('nan')
        stochastic_cov = stochastic_hits / n_iterations if n_iterations else float('nan')

        cvar_estimates.append(avg_cvar)
        thomas_lower_bounds.append(avg_lower)
        thomas_upper_bounds.append(avg_upper)
        stochastic_lower_bounds.append(avg_stochastic_lower)
        stochastic_upper_bounds.append(avg_stochastic_upper)
        thomas_coverage.append(thomas_cov)
        stochastic_coverage.append(stochastic_cov)

        print(f"Sample size {n}: Avg CVaR={avg_cvar:.4f}, Avg Lower={avg_lower:.4f}, Avg Upper={avg_upper:.4f}, Stochastic Lower={avg_stochastic_lower:.4f}, Stochastic Upper={avg_stochastic_upper:.4f}")
        print(f"    Per-iteration coverage over {n_iterations} runs: Thomas={thomas_cov:.2%}, Stochastic={stochastic_cov:.2%}")

    if n_out_of_support:
        print(f"Warning: {n_out_of_support} draws for {dist_name} fell outside the assumed support "
              f"{dist_bounds}; bounds computed from this support may not be valid.")

    return {
        'name': dist_name,
        'true_cvar': true_cvar,
        'cvar_estimates': cvar_estimates,
        'lower_bounds': thomas_lower_bounds,
        'upper_bounds': thomas_upper_bounds,
        'stochastic_lower_bounds': stochastic_lower_bounds,
        'stochastic_upper_bounds': stochastic_upper_bounds,
        'thomas_coverage': thomas_coverage,
        'stochastic_coverage': stochastic_coverage,
    }

def _coverage_rate(lower_bounds, upper_bounds, target) -> float:
    """
    Return the fraction of positions where ``lower <= target <= upper``.

    The inputs are converted with ``np.asarray`` so that plain Python lists
    compare element-wise even when ``target`` is a built-in float
    (``list <= float`` raises TypeError). Returns NaN for empty input.
    """
    lower = np.asarray(lower_bounds, dtype=float)
    upper = np.asarray(upper_bounds, dtype=float)
    if lower.size == 0:
        return float('nan')
    return float(np.mean((lower <= target) & (upper >= target)))

def main():
    """
    Compare CVaR bounds across several distributions, plot them and print a summary.

    The function seeds NumPy's global RNG, which the scipy ``rvs`` calls in
    the sample generators use by default. It then runs
    ``run_simulation_for_distribution`` for every entry in ``distributions``,
    saves a grid of plots to ``./dist_disc_plots/cvar_bounds_comparison.png``
    and prints summary statistics.

    Note: the "coverage" reported here is the fraction of *sample sizes* whose
    iteration-averaged bounds bracket the reference CVaR. It is not the
    per-draw coverage probability that the (1 - delta) guarantee refers to.
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Parameters
    alpha = 0.05  # CVaR confidence level
    delta = 0.05  # Probability for Thomas bounds
    n_iterations = 30  # Number of iterations to average over
    sample_sizes = np.arange(100, 1001, 50)  # From 100 to 1000 in steps of 50

    # Define distributions to test
    distributions = [
        {
            'name': 'Truncated Normal(-1,1)',
            'generator': generate_truncated_normal_samples,
            'params': {'lower_bound': -1, 'upper_bound': 1},
            'bounds': (-1, 1)
        },
        {
            'name': 'Beta(2,2)',
            'generator': generate_beta_samples,
            'params': {'alpha': 2, 'beta': 2},
            'bounds': (0, 1)
        },
        {
            'name': 'Beta(0.5,0.5)',
            'generator': generate_beta_samples,
            'params': {'alpha': 0.5, 'beta': 0.5},
            'bounds': (0, 1)
        },
        {
            'name': 'Beta(2,5)',
            'generator': generate_beta_samples,
            'params': {'alpha': 2, 'beta': 5},
            'bounds': (0, 1)
        },
        {
            'name': 'Beta(5,2)',
            'generator': generate_beta_samples,
            'params': {'alpha': 5, 'beta': 2},
            'bounds': (0, 1)
        },
        {
            'name': 'Beta(10,2)',
            'generator': generate_beta_samples,
            'params': {'alpha': 10, 'beta': 2},
            'bounds': (0, 1)
        },
        {
            'name': 'Beta(2,10)',
            'generator': generate_beta_samples,
            'params': {'alpha': 2, 'beta': 10},
            'bounds': (0, 1)
        },
        {
            'name': 'Laplace(0,1)',
            'generator': generate_laplace_samples,
            'params': {'loc': 0, 'scale': 1},
            'bounds': (-10, 10)  # Approximate bounds for Laplace (unbounded support)
        }
    ]

    # Run simulations for all distributions
    results = []
    for dist in distributions:
        result = run_simulation_for_distribution(
            dist['name'], dist['generator'], dist['params'],
            sample_sizes, alpha, delta, n_iterations, dist['bounds']
        )
        results.append(result)

    # (thomas, stochastic) coverage per distribution, computed once and
    # reused by both the plot annotations and the printed summary.
    coverages = [
        (
            _coverage_rate(r['lower_bounds'], r['upper_bounds'], r['true_cvar']),
            _coverage_rate(r['stochastic_lower_bounds'], r['stochastic_upper_bounds'], r['true_cvar']),
        )
        for r in results
    ]

    # Size the subplot grid from the number of distributions (two columns);
    # with 8 distributions this is the original 4x2 grid at 16x24 inches.
    n_cols = 2
    n_rows = max(1, -(-len(results) // n_cols))  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False)
    axes = axes.flatten()
    for unused_ax in axes[len(results):]:
        unused_ax.set_visible(False)

    for ax, result, (thomas_coverage_rate, stochastic_coverage_rate) in zip(axes, results, coverages):
        # Plot CVaR estimates and bounds
        ax.plot(sample_sizes, result['cvar_estimates'], 'b-o',
                label=f'CVaR Estimate (α={alpha})', linewidth=2, markersize=4)
        ax.plot(sample_sizes, result['lower_bounds'], 'r--s',
                label=f'Thomas Lower Bound (δ={delta})', linewidth=2, markersize=4)
        ax.plot(sample_sizes, result['upper_bounds'], 'g--^',
                label=f'Thomas Upper Bound (δ={delta})', linewidth=2, markersize=4)
        ax.plot(sample_sizes, result['stochastic_lower_bounds'], 'm--d',
                label='Stochastic Lower Bound', linewidth=2, markersize=4)
        ax.plot(sample_sizes, result['stochastic_upper_bounds'], 'c--v',
                label='Stochastic Upper Bound', linewidth=2, markersize=4)
        ax.axhline(y=result['true_cvar'], color='k', linestyle='-', alpha=0.7,
                   label='True CVaR')

        # Customize subplot
        ax.set_xlabel('Sample Size', fontsize=10)
        ax.set_ylabel('CVaR Value', fontsize=10)
        ax.set_title(f'{result["name"]}\nα={alpha}, δ={delta}', fontsize=12)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

        # Add statistics
        stats_text = f'True CVaR: {result["true_cvar"]:.4f}\n'
        stats_text += f'Thomas Coverage: {thomas_coverage_rate:.2%}\n'
        stats_text += f'Stochastic Coverage: {stochastic_coverage_rate:.2%}'

        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                fontsize=8)

    plt.tight_layout()

    # Create directory if it doesn't exist
    output_dir = Path('./dist_disc_plots')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save the plot as PNG
    output_file = output_dir / 'cvar_bounds_comparison.png'
    plt.savefig(str(output_file), format='png', dpi=300, bbox_inches='tight')
    print(f"Plot saved as '{output_file}'")
    plt.close()

    # Print summary statistics
    print("\n" + "="*100)
    print("SUMMARY STATISTICS")
    print("="*100)
    for result, (thomas_coverage_rate, stochastic_coverage_rate) in zip(results, coverages):
        print(f"{result['name']:20s}: True CVaR = {result['true_cvar']:.4f}, Thomas Coverage = {thomas_coverage_rate:.2%}, Stochastic Coverage = {stochastic_coverage_rate:.2%}")

if __name__ == "__main__":
    # Run the full simulation only when executed as a script, not on import.
    main()
