import os
import pickle
import time

import numpy as np
import ray
from scipy.stats import cauchy, laplace, norm, uniform
from sortedcontainers import SortedList
from tqdm import tqdm

def generate_data(dist_type, tau, n_samples, seed=2025):
    """
    Draw ``n_samples`` values from a named distribution and report its exact tau-quantile.

    The global NumPy RNG is reseeded with ``seed`` on every call (even when
    ``dist_type`` turns out to be unsupported), so identical arguments always
    produce identical samples.

    Parameters
    ----------
    dist_type : str
        One of 'normal' (N(0, 1)), 'uniform' (U(-1, 1)), 'cauchy' (standard
        Cauchy) or 'laplace' (location 0, scale 1).
    tau : float
        Quantile level at which the theoretical quantile is evaluated.
    n_samples : int
        Length of the generated sample.
    seed : int, optional
        Value passed to ``np.random.seed`` before sampling (default 2025).

    Returns
    -------
    data : np.ndarray
        The ``n_samples`` draws.
    true_q : float
        Population tau-quantile of the chosen distribution.

    Raises
    ------
    ValueError
        If ``dist_type`` is not one of the supported names.
    """
    np.random.seed(seed)

    # (name, sampler(size), population quantile(level)); matched in order with ==.
    catalogue = (
        ('normal', lambda size: np.random.normal(0, 1, size), norm.ppf),
        ('uniform', lambda size: np.random.uniform(-1, 1, size), lambda level: -1 + 2 * level),
        ('cauchy', np.random.standard_cauchy, cauchy.ppf),
        ('laplace', lambda size: np.random.laplace(0, 1, size), laplace.ppf),
    )
    for name, sampler, quantile in catalogue:
        if dist_type == name:
            # The sampler runs before the quantile, matching the original order of work.
            return sampler(n_samples), quantile(tau)
    raise ValueError("Unsupported distribution type")

def compute_confidence_bounds_sortedlist_single(dist_type, tau, n_samples, seed=2025, alpha=0.05, interval=1):
    """
    Compute sequential lower/upper confidence bounds for the tau-quantile of one simulated stream.

    A stream of ``n_samples`` points is generated with
    ``generate_data(dist_type, tau, n_samples, seed=seed)`` and fed into a
    ``SortedList`` ``interval`` points at a time. After each batch, at sample
    sizes Kt = interval, 2*interval, ... (<= n_samples), the bounds are read
    off as order statistics of the data seen so far.

    Parameters
    ----------
    dist_type : str
        The type of distribution (see ``generate_data``).
    tau : float
        The quantile to compute.
    n_samples : int
        Number of samples to generate.
    seed : int, optional
        Random seed forwarded to ``generate_data``; different seeds give
        independent simulations.
    alpha : float, optional
        Significance level for the confidence interval.
    interval : int, optional
        Spacing between consecutive evaluation points (must be >= 1).

    Returns
    -------
    lu_list : np.ndarray of float16
        Lower bound at each evaluation point.
    up_list : np.ndarray of float16
        Upper bound at each evaluation point.
    interval : int
        The interval used for sampling points.

    If ``n_samples < interval`` a warning is printed and ``([], [], interval)``
    is returned instead.

    Raises
    ------
    ValueError
        If ``interval < 1`` or ``dist_type`` is unsupported.
    """
    if interval < 1:
        raise ValueError("interval must be a positive integer")
    if n_samples < interval:
        print("Warning: n_samples < interval")
        return [], [], interval

    sample_points = range(interval, n_samples + 1, interval)
    n_points = len(sample_points)

    # Preallocate the outputs instead of growing Python lists of floats.
    lu_arr = np.empty(n_points, dtype=np.float64)
    up_arr = np.empty(n_points, dtype=np.float64)

    # --- Data generation (one simulation) ---
    # `seed` must be forwarded: generate_data reseeds the global RNG itself, and
    # with its default seed every simulation would draw the identical stream.
    all_data_single, _ = generate_data(dist_type, tau, n_samples, seed=seed)

    # Terms of the radius that do not depend on Kt, hoisted out of the loop.
    log_10_over_alpha = np.log(10 / alpha)
    tau_var = max(0, tau * (1 - tau))

    sl = SortedList()
    current_data_idx = 0
    for i, Kt in tqdm(enumerate(sample_points), total=n_points, desc=f"SortedList {dist_type}"):
        # 1. Add the next batch of `interval` points (SortedList.update).
        sl.update(all_data_single[current_data_idx:Kt])
        current_data_idx = Kt

        # 2. Confidence radius at sample size Kt. The log(log(.)) term suggests a
        #    time-uniform (LIL-type) boundary. NOTE(review): confirm the constants
        #    1.4 / 2.1 / 1.5 / 0.8 against the reference method.
        l_t = (1.4 * np.log(np.log(max(2.1, 2.1 * Kt))) + log_10_over_alpha) / Kt
        ft_p = 1.5 * np.sqrt(tau_var * l_t) + 0.8 * l_t

        # Theoretical 1-based order-statistic ranks (floor for lower, ceil for upper).
        k_lower = int(np.floor(Kt * (tau - ft_p)))
        k_upper = int(np.ceil(Kt * (tau + ft_p)))

        # Convert to 0-based indices, clamped to [0, Kt - 1] (sl holds Kt items).
        # Clamping both sides keeps a negative index from silently wrapping to the max.
        lower_idx = min(Kt - 1, max(0, k_lower - 1))
        upper_idx = min(Kt - 1, max(0, k_upper - 1))

        # 3. Indexed access into a SortedList returns the k-th smallest element.
        lu_arr[i] = sl[lower_idx]
        up_arr[i] = sl[upper_idx]

    # float16 keeps the per-simulation result small.
    # NOTE(review): values beyond ~65504 (possible for Cauchy at small Kt) become inf.
    return lu_arr.astype(np.float16), up_arr.astype(np.float16), interval

# Initialise Ray at module import time, before any remote task is submitted.
# NOTE(review): because this runs on import, importing this module from
# elsewhere also initialises Ray.
ray.init(
    runtime_env={
        # Ray runtime environment: `working_dir` is shipped to the workers so
        # remote tasks can use this directory's code.
        "working_dir": ".", # Current directory
        # Glob patterns left out of that upload. They look like large artifacts
        # (snapshots, pickled results, archives, presumably core dumps, HDF5 and
        # NumPy data files).
        "excludes": [
            "*.snap",
            "*.pkl",
            "*.rar",
            "core.*",
            "core",
            "*.h5",
            "*.npy"
        ]
    }
)

@ray.remote
def compute_confidence_bounds_sortedlist_remote(dist_type, tau, n_samples, seed=2025, alpha=0.05, interval=100):
    """
    Execute ``compute_confidence_bounds_sortedlist_single`` as a Ray task.

    Every argument is handed through unchanged. The return value is the
    ``(lu_list, up_list, interval)`` tuple produced by the local function.

    NOTE(review): the default ``interval`` here is 100, while the local
    function defaults to 1. ``run_confidence_bounds_sortedlist_simulation``
    always passes ``interval`` explicitly, so only other callers see this
    default.

    Parameters
    ----------
    dist_type : str
    tau : float
    n_samples : int
    seed : int, optional
    alpha : float, optional
    interval : int, optional

    Returns
    -------
    tuple
        Output of ``compute_confidence_bounds_sortedlist_single``.
    """
    bounds = compute_confidence_bounds_sortedlist_single(
        dist_type=dist_type,
        tau=tau,
        n_samples=n_samples,
        seed=seed,
        alpha=alpha,
        interval=interval,
    )
    return bounds
    
def run_confidence_bounds_sortedlist_simulation(dist_type, tau, n_samples, n_simu, base_seed=2025, alpha=0.05, interval=1):
    """
    Run ``n_simu`` independent simulations in parallel on Ray and stack their bounds.

    Simulation ``i`` uses seed ``base_seed + i``. Row ``i`` of the returned
    arrays always corresponds to that seed, regardless of the order in which
    the Ray tasks finish.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    tau : float
        The quantile to compute.
    n_samples : int
        Number of samples to generate per simulation.
    n_simu : int
        Number of simulations to run (must be >= 1).
    base_seed : int, optional
        Seed of the first simulation; later ones use consecutive seeds.
    alpha : float, optional
        Significance level for the confidence interval.
    interval : int, optional
        Interval for sampling points.

    Returns
    -------
    dict
        ``'interval'``: the interval reported by the tasks;
        ``'lu_list'`` / ``'up_list'``: float16 arrays with one row per simulation.

    Raises
    ------
    ValueError
        If ``n_simu < 1``; there would be no result to read the interval from.
    """
    if n_simu < 1:
        raise ValueError("n_simu must be at least 1")

    # Submit parallel tasks (one per seed).
    futures = [
        compute_confidence_bounds_sortedlist_remote.remote(
            seed=base_seed + i, dist_type=dist_type, tau=tau,
            n_samples=n_samples, alpha=alpha, interval=interval,
        )
        for i in range(n_simu)
    ]

    # ray.wait hands back tasks in completion order. Use it only to drive the
    # progress bar, then fetch results in submission order so row i <-> seed i.
    pending = list(futures)
    with tqdm(total=n_simu, desc=f"Processing tau={tau}") as pbar:
        while pending:
            done, pending = ray.wait(pending)
            pbar.update(len(done))
    results = ray.get(futures)

    interval = results[0][2]  # every task returns the same interval
    lu_list = np.array([r[0] for r in results]).astype(np.float16)
    up_list = np.array([r[1] for r in results]).astype(np.float16)

    return {
        'interval': interval,
        'lu_list': lu_list,
        'up_list': up_list,
    }
    
def save_simulation_results(dist_type, tau_list, n_samples, n_simu, base_seed, alpha, interval, output_dir):
    """
    Run the simulation for each quantile in ``tau_list`` and pickle the results.

    For every ``tau`` two files are written into ``output_dir``. The directory
    is created if missing.

    * ``confidence_bounds_{dist_type}_tau{tau}_n500_ge1.pkl`` holds the
      per-evaluation-point mean of ``up_list - lu_list`` across simulations.
    * ``confidence_luup_{dist_type}_tau{tau}_n500_ge1.pkl`` holds the full
      dict returned by ``run_confidence_bounds_sortedlist_simulation``.

    NOTE(review): the ``_n500_ge1`` suffix is hard-coded and does not follow
    ``n_simu`` / ``interval``. Runs with other settings overwrite the same files.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    tau_list : list of float
        List of quantiles to compute.
    n_samples : int
        Number of samples to generate.
    n_simu : int
        Number of simulations to run.
    base_seed : int
        Base random seed.
    alpha : float
        Significance level for the confidence interval.
    interval : int
        Interval for sampling points.
    output_dir : str
        Directory to save the results.
    """
    # Requires module-level `import os` and `import pickle`.
    os.makedirs(output_dir, exist_ok=True)
    for tau in tau_list:
        luup = run_confidence_bounds_sortedlist_simulation(
            dist_type=dist_type, tau=tau, n_samples=n_samples, n_simu=n_simu,
            base_seed=base_seed, alpha=alpha, interval=interval,
        )
        # Mean interval width at each evaluation point, averaged over simulations.
        bound = np.mean(luup['up_list'] - luup['lu_list'], axis=0)

        bounds_path = os.path.join(output_dir, f"confidence_bounds_{dist_type}_tau{tau}_n500_ge1.pkl")
        luup_path = os.path.join(output_dir, f"confidence_luup_{dist_type}_tau{tau}_n500_ge1.pkl")
        with open(bounds_path, 'wb') as f:
            pickle.dump(bound, f)
        with open(luup_path, 'wb') as f:
            pickle.dump(luup, f)

        # Drop references before the next (large) simulation to lower peak memory.
        del bound, luup
        print(f"Results saved in: {output_dir}")

if __name__ == "__main__":
    # Run the simulation study for each distribution and persist the results
    # under ./output_ga/<distribution>/ (normal, then cauchy, then laplace).
    tau_list = [0.8, 0.5, 0.3]
    n_samples = 5_000_000
    n_simu = 500
    base_seed = 2025
    alpha = 0.05
    interval = 1

    for dist_name in ('normal', 'cauchy', 'laplace'):
        save_simulation_results(
            dist_name, tau_list, n_samples, n_simu, base_seed, alpha, interval,
            f"./output_ga/{dist_name}",
        )
    
    
    
    
    
    
    
    
