import os
import sys
from util import *
from GADPQuantile import compute_radius,GADPQuantile_kchange_log
import numpy as np
import ray
from tqdm import tqdm
import time

def save_hdf5(var, file_path):
    """
    Save variables to an HDF5 file.

    Each entry of ``var`` is written under its key using the first strategy
    that succeeds:

    * ``np.ndarray``: stored as a dataset (gzip-compressed unless 0-d).
    * ``list``: converted with ``np.array`` and stored as a dataset; if that
      fails, the list is serialized with ``json.dumps`` and stored as a
      one-element variable-length string dataset.
    * anything else: stored as a file-level attribute; if that fails,
      ``str(value)`` is stored as a one-element variable-length string
      dataset.

    Parameters
    ----------
    var : dict
        Dictionary containing variables to be saved.
    file_path : str
        Path to the HDF5 file. Missing parent directories are created and
        the file is opened in ``'w'`` mode.
    """
    import h5py
    import json

    def _write_array(h5file, name, array):
        # h5py rejects chunk/filter options on scalar (0-d) datasets, so
        # compression is only requested for arrays with at least one axis.
        compression = "gzip" if array.ndim > 0 else None
        h5file.create_dataset(name, data=array, compression=compression)

    def _write_text(h5file, name, text):
        # Store text as a single variable-length string element.
        dt = h5py.special_dtype(vlen=str)
        dset = h5file.create_dataset(name, (1,), dtype=dt)
        dset[0] = text

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(file_path, 'w') as f:
        for key, value in var.items():
            if isinstance(value, np.ndarray):
                # Save NumPy array directly as dataset
                _write_array(f, key, value)
            elif isinstance(value, list):
                # Convert list to NumPy array and save
                try:
                    _write_array(f, key, np.array(value))
                except Exception:
                    # If conversion or storage fails, fall back to a JSON string.
                    # `Exception` (not a bare except) keeps Ctrl-C working.
                    _write_text(f, key, json.dumps(value))
            else:
                # Try to save other simple types as attributes
                try:
                    f.attrs[key] = value
                except Exception:
                    # If the value cannot be stored as an attribute,
                    # convert to string and save as dataset
                    _write_text(f, key, str(value))
    
def train(seed=None, dist_type='normal', tau=0.5, r=0.5,
          n_samples=1000, alpha=0.05, c0=2, a=0.51,
          burn_in_ratio=0.01, radius_typ='gm'):
    """
    Run a single simulation replicate: seed NumPy, fit the estimator, and
    return its recorded estimates.

    Parameters
    ----------
    seed : int or None
        Value passed to ``np.random.seed`` before anything else runs.
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    n_samples : int
        Number of samples to generate.
    alpha : float
        Significance level for confidence interval.
    c0 : float
        Hyperparameter about the learning rate for the estimator.
    a : float
        Hyperparameter about the learning rate for the estimator.
    burn_in_ratio : float
        Hyperparameter about ratio of burn-in samples among total n_samples.
    radius_typ : str
        Type of radius calculation.

    Returns
    -------
    dict
        ``'true_q'`` (result of ``generate_true_q``), ``'estimates'`` and
        ``'variances'`` (the model's ``global_means`` and ``global_vars``
        as float16 arrays).
    """
    # Seeding happens first so everything below draws from the seeded
    # global NumPy RNG state.
    np.random.seed(seed)
    true_q = generate_true_q(dist_type, tau)

    # All estimator settings are forwarded unchanged.
    estimator_kwargs = dict(
        r=r,
        tau=tau,
        true_q=true_q,
        alpha=alpha,
        c0=c0,
        a=a,
        burn_in_ratio=burn_in_ratio,
        radius_typ=radius_typ,
        n_samples=n_samples,
    )
    model = GADPQuantile_kchange_log(**estimator_kwargs)
    model.fit(dist_type, tau, n_samples)

    means, variances = model.global_means, model.global_vars

    def _to_half(values):
        # float16 halves/quarters the memory footprint of the result.
        # NOTE(review): float16 overflows to inf above ~65504 and keeps only
        # ~3 significant digits -- confirm this is acceptable for heavy
        # tailed inputs.
        return np.array(values).astype(np.float16)

    return {
        'true_q': true_q,
        'estimates': _to_half(means),
        'variances': _to_half(variances),
    }

# Start (or connect to) Ray at import time. The runtime environment ships the
# current directory to the workers so they can import the local modules
# (e.g. `util`, `GADPQuantile`).
ray.init(
    runtime_env={
        "working_dir": ".",
        # Patterns left out of the working_dir upload: snapshots, pickles,
        # archives, core dumps and previously written result files.
        "excludes": [
            "*.snap",         
            "*.pkl",          
            "*.rar",           
            "core.*",           
            "core",  
            "*.h5",
            "*.npy"
        ]
    }
)

@ray.remote
def train_remote(seed=None, dist_type='normal', tau=0.5, r=0.5,
                 n_samples=1000, alpha=0.05, c0=2, a=0.51,
                 burn_in_ratio=0.01, radius_typ='gm'):

    """
    Ray remote wrapper for train.

    Every argument is forwarded unchanged to ``train``.

    Parameters
    ----------
    seed : int or None
        Random seed for reproducibility.
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    n_samples : int
        Number of samples to generate.
    alpha : float
        Significance level for confidence interval.
    c0 : float
        Hyperparameter about the learning rate for the estimator.
    a : float
        Hyperparameter about the learning rate for the estimator.
    burn_in_ratio : float
        Hyperparameter about ratio of burn-in samples among total n_samples.
    radius_typ : str
        Type of radius calculation.

    Returns
    -------
    dict
        Dictionary containing true quantile, point estimates and variances.
    """
    # Forward the caller's seed. Passing a literal None here would discard
    # it and make every replicate non-reproducible.
    return train(seed=seed, dist_type=dist_type, tau=tau, r=r,
                 n_samples=n_samples, alpha=alpha, c0=c0, a=a,
                 burn_in_ratio=burn_in_ratio, radius_typ=radius_typ)


def run_simulation(dist_type='normal', tau=0.5, r=0.5, 
                   n_samples=1000, n_simu=100, base_seed=2025, alpha=0.05,
                   c0=2, a=0.51, 
                   burn_in_ratio=0.01, radius_typ='gm'):
    """
    Main execution function for running multiple simulations in parallel and collecting results.

    Replicate ``i`` is submitted with seed ``base_seed + i`` and its result is
    stored in row ``i`` of the returned arrays, regardless of the order in
    which the Ray tasks finish.

    Parameters
    ----------
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    n_samples : int
        Number of samples to generate.
    n_simu : int
        Number of simulations to run. Must be at least 1.
    base_seed : int
        Base random seed.
    alpha : float
        Significance level for confidence interval. Used both for the
        estimator and for the radius computation.
    c0 : float
        Hyperparameter about the learning rate for the estimator.
    a : float
        Hyperparameter about the learning rate for the estimator.
    burn_in_ratio : float
        Hyperparameter about ratio of burn-in samples among total n_samples.
    radius_typ : str
        Type of radius calculation.

    Returns
    -------
    dict
        Dictionary containing true quantile, point estimates, variances and radius.

    Raises
    ------
    ValueError
        If ``n_simu`` is smaller than 1.
    """
    if n_simu < 1:
        # Without any replicate there is no result to read `true_q` from.
        raise ValueError(f"n_simu must be at least 1, got {n_simu}")

    # Submit parallel tasks
    print(f'n_samples:{n_samples}')
    futures = [train_remote.remote(seed=base_seed + i, dist_type=dist_type,
                                   tau=tau, r=r,
                                   n_samples=n_samples, alpha=alpha, c0=c0, a=a,
                                   burn_in_ratio=burn_in_ratio, radius_typ=radius_typ) for i in range(n_simu)]

    # Remember each task's submission index so results can be stored in seed
    # order even though ray.wait hands them back in completion order.
    slot_of = {ref: idx for idx, ref in enumerate(futures)}
    results = [None] * n_simu

    # Collect results as they finish; the context manager closes the
    # progress bar even if a task raises.
    with tqdm(total=n_simu, desc=f"Processing r={r}, tau={tau}, radius_typ={radius_typ}, dist_type={dist_type}") as pbar:
        while futures:
            done, futures = ray.wait(futures)
            for ref, value in zip(done, ray.get(done)):
                # pop() drops our reference to the finished task.
                results[slot_of.pop(ref)] = value
            pbar.update(len(done))

    # Calculate statistics
    true_q = results[0]['true_q']
    points = np.array([res['estimates'] for res in results])
    variances = np.array([res['variances'] for res in results])
    # One radius per recorded step (1-based), at the caller's significance level.
    radius = compute_radius(np.arange(variances.shape[1])+1, rho = 0.001, m = 1, alpha=alpha, typ=radius_typ)

    return {
        'true_q': true_q,
        'estimates': points,
        'variances': variances,
        'radius': radius
    }

# ---- Experiment configuration ----
seed = 2025
n_samples = 5000000
n_sim = 2000
dist_types = ['normal','cauchy','laplace']
# NOTE(review): the docstrings describe burn_in_ratio as a ratio of n_samples,
# yet the value here is 4 -- confirm how GADPQuantile_kchange_log interprets it.
burn_in_ratio = 4

# NOTE(review): `time` and `os` are already imported at the top of the file and
# `pickle` is not used in this section; kept in case later code relies on them.
import time
import os
import pickle
radius_typs = ['ub','gm']
r_ = [1,0.9,0.75,0.5,0.25] 
a = 0.6
tau_ = [0.8,0.5,0.3]
c0 = 1
# Number of completed configurations (not read anywhere in this section).
ct = 0

# ---- Full sweep over quantile level x response rate x distribution x radius type ----
for tau in tau_:
    for r in r_:
        for dist_type in dist_types:
            for radius_typ in radius_typs:                
                print(f'\r dist_type:{dist_type} radius_typ:{radius_typ} r:{r} tau:{tau} :')
                t1 = time.time()
                output = run_simulation(dist_type=dist_type, tau=tau,
                                        r=r, n_samples=n_samples,
                                        n_simu=n_sim, base_seed=seed,alpha=0.05,
                                        c0=c0,a=a,burn_in_ratio=burn_in_ratio,radius_typ=radius_typ)
                # Timing covers the simulation only, not the HDF5 write below.
                t2 = time.time()
                ct+=1
                # Build the output path once so the saved file and the log
                # message can never disagree.
                out_path = f'./output_ga/{dist_type}/Kchange_radius_typ_{radius_typ}_n_samples_{n_samples}_r_{r}_tau_{tau}_a_{a}_c_{c0}_burnin_ratio_new_{burn_in_ratio}_n_sim_{n_sim}.h5'
                save_hdf5(output, out_path)
                print(f'Saved in {out_path}')
                # Release the result arrays before the next configuration starts.
                del output
                print(f"Elapsed time for this run: {t2 - t1:.2f} seconds")