import os
import sys
from util import *
from GADPQuantile import compute_radius,GADPQuantile_kfix
import numpy as np
import ray
from tqdm import tqdm
import time

def save_hdf5(var, file_path):
    """
    Save variables to an HDF5 file.

    Each entry of ``var`` is written under its own key:

    * NumPy arrays are stored as gzip-compressed datasets.
    * Lists are converted to NumPy arrays and stored the same way; if that
      fails, the list is serialized to JSON and stored as a one-element
      string dataset.
    * Any other value is stored as a file-level attribute; if that fails,
      ``str(value)`` is stored as a one-element string dataset.

    Parameters
    ----------
    var : dict
        Dictionary containing variables to be saved.
    file_path : str
        Path to the HDF5 file. Missing parent directories are created and
        the file is opened in ``'w'`` mode.
    """
    import h5py
    import json

    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(file_path, 'w') as f:

        def _write_text(name, text):
            # Store text as a one-element variable-length string dataset.
            dt = h5py.special_dtype(vlen=str)
            dset = f.create_dataset(name, (1,), dtype=dt)
            dset[0] = text

        for key, value in var.items():
            if isinstance(value, np.ndarray):
                # Save NumPy array directly as dataset
                f.create_dataset(key, data=value, compression="gzip")
            elif isinstance(value, list):
                # Convert list to NumPy array and save
                try:
                    arr_value = np.array(value)
                    f.create_dataset(key, data=arr_value, compression="gzip")
                except Exception:
                    # Best-effort fallback: NumPy and h5py raise several
                    # different exception types here, so catch Exception
                    # broadly, but not BaseException, so Ctrl-C and
                    # SystemExit still propagate.
                    _write_text(key, json.dumps(value))
            else:
                # Try to save other simple types as attributes
                try:
                    f.attrs[key] = value
                except Exception:
                    # Attribute could not be stored (e.g. unsupported type or
                    # too large): fall back to its string form in a dataset.
                    _write_text(key, str(value))
                    
                    
def train(seed=None, dist_type='normal', tau=0.5, r=0.5, K=20,
          n_samples=1000, alpha=0.05, c0=2, a=0.51, b=0,
          burn_in_ratio=0.01):
    """
    Run a single replicate: seed NumPy, fit the estimator, return its output.

    The global NumPy RNG is seeded with ``seed``, the reference quantile is
    obtained from ``generate_true_q``, and a ``GADPQuantile_kfix`` model is
    constructed and fitted. The model's ``global_means`` and ``global_vars``
    are returned as float16 arrays.

    Parameters
    ----------
    seed : int or None
        Value passed to ``np.random.seed``.
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    K : int
        Number of chains or estimators.
    n_samples : int
        Number of samples to generate.
    alpha : float
        Significance level for confidence interval.
    c0 : float
        Learning-rate hyperparameter forwarded to the estimator.
    a : float
        Learning-rate hyperparameter forwarded to the estimator.
    b : float
        Learning-rate hyperparameter forwarded to the estimator.
    burn_in_ratio : float
        Burn-in hyperparameter forwarded to the estimator.

    Returns
    -------
    dict
        ``'true_q'`` (result of ``generate_true_q``), ``'estimates'`` and
        ``'variances'`` (float16 NumPy arrays).
    """
    np.random.seed(seed)
    target_quantile = generate_true_q(dist_type, tau)

    # Collect the constructor arguments first, then build the estimator.
    estimator_kwargs = dict(
        r=r, tau=tau, K=K,
        true_q=target_quantile, alpha=alpha, c0=c0, a=a, b=b,
        burn_in_ratio=burn_in_ratio, n_samples=n_samples,
    )
    estimator = GADPQuantile_kfix(**estimator_kwargs)
    estimator.fit(dist_type, tau, n_samples)

    raw_means, raw_vars = estimator.global_means, estimator.global_vars

    # Half precision keeps the returned arrays small.
    half = np.float16
    result = {'true_q': target_quantile}
    result['estimates'] = np.array(raw_means).astype(half)
    result['variances'] = np.array(raw_vars).astype(half)
    return result

# Start Ray with a runtime environment rooted at the current directory.
# The exclude patterns keep matching files (snapshots, pickles, archives,
# core files, HDF5/NumPy outputs) out of the working_dir that Ray packages.
ray.init(
    runtime_env=dict(
        working_dir=".",
        excludes=[
            "*.snap",
            "*.pkl",
            "*.rar",
            "core.*",
            "core",
            "*.h5",
            "*.npy",
        ],
    )
)

@ray.remote
def train_remote(seed=None, dist_type='normal', tau=0.5, r=0.5, K=20,
                 n_samples=1000, alpha=0.05, c0=2, a=0.51, b=0,
                 burn_in_ratio=0.01):
    """
    Ray task that forwards every argument unchanged to ``train``.

    Parameters
    ----------
    seed : int or None
        Random seed for reproducibility.
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    K : int
        Number of chains or estimators.
    n_samples : int
        Number of samples to generate.
    alpha : float
        Significance level for confidence interval.
    c0 : float
        Learning-rate hyperparameter for the estimator.
    a : float
        Learning-rate hyperparameter for the estimator.
    b : float
        Learning-rate hyperparameter for the estimator.
    burn_in_ratio : float
        Burn-in hyperparameter for the estimator.

    Returns
    -------
    dict
        Whatever ``train`` returns for these arguments (true quantile, point
        estimates and variances).
    """
    forwarded = dict(
        seed=seed,
        dist_type=dist_type,
        tau=tau,
        r=r,
        K=K,
        n_samples=n_samples,
        alpha=alpha,
        c0=c0,
        a=a,
        b=b,
        burn_in_ratio=burn_in_ratio,
    )
    return train(**forwarded)

def run_simulation(dist_type='normal', tau=0.5, r=0.5, K=20,
                   n_samples=1000, n_simu=100, base_seed=2025, alpha=0.05,
                   c0=2, a=0.51, b=0,
                   burn_in_ratio=0.01):
    """
    Main execution function for running multiple simulations in parallel and collecting results.

    One ``train_remote`` task is submitted per replicate, with replicate ``i``
    seeded by ``base_seed + i``. Results are gathered as tasks finish (so the
    progress bar advances live) but are stored in submission order, so row
    ``i`` of the returned arrays always belongs to seed ``base_seed + i``.

    Parameters
    ----------
    dist_type : str
        The type of distribution ('normal', 'cauchy', 'laplace', etc.).
    tau : float
        The quantile to compute.
    r : float
        Response rate for the estimator.
    K : int
        Number of chains or estimators.
    n_samples : int
        Number of samples to generate.
    n_simu : int
        Number of simulations to run. Must be at least 1.
    base_seed : int
        Base random seed.
    alpha : float
        Significance level for confidence interval.
    c0 : float
        Hyperparameter about the learning rate for the estimator.
    a : float
        Hyperparameter about the learning rate for the estimator.
    b : float
        Hyperparameter about the learning rate for the estimator.
    burn_in_ratio : float
        Hyperparameter about ratio of burn-in samples among total n_samples.

    Returns
    -------
    dict
        ``'true_q'`` taken from the first replicate, plus ``'estimates'`` and
        ``'variances'`` stacked over replicates along the first axis.

    Raises
    ------
    ValueError
        If ``n_simu`` is smaller than 1 (there would be no result to return).
    """
    if n_simu < 1:
        raise ValueError(f"n_simu must be at least 1, got {n_simu}")

    # Submit parallel tasks
    print(f'n_samples:{n_samples}')
    pending = [train_remote.remote(seed=base_seed + i, dist_type=dist_type,
                                   tau=tau, r=r, K=K,
                                   n_samples=n_samples, alpha=alpha, c0=c0, a=a, b=b,
                                   burn_in_ratio=burn_in_ratio) for i in range(n_simu)]

    # Remember which replicate each task belongs to, so results can be placed
    # by index rather than in (nondeterministic) completion order.
    slot_of = {ref: i for i, ref in enumerate(pending)}
    results = [None] * n_simu

    # Collect results as they complete, tracking progress with tqdm.
    pbar = tqdm(total=n_simu, desc=f"Processing r={r}, tau={tau}, K={K}")
    try:
        while pending:
            done, pending = ray.wait(pending)
            # ray.get on a list returns values in the order of that list.
            for ref, outcome in zip(done, ray.get(done)):
                # pop() also drops our reference to the finished task.
                results[slot_of.pop(ref)] = outcome
            pbar.update(len(done))
    finally:
        # Close the bar even if a task raised, so the terminal is left clean.
        pbar.close()

    # Stack per-replicate outputs
    true_q = results[0]['true_q']
    points = np.array([res['estimates'] for res in results])
    variances = np.array([res['variances'] for res in results])

    return {
        'true_q': true_q,
        'estimates': points,
        'variances': variances
    }

# ---- Experiment configuration ----
seed = 2025                 # base seed handed to run_simulation
n_samples = 5_000_000       # samples per replicate
n_sim = 500                 # replicates per configuration
dist_types = ['normal', 'cauchy', 'laplace']
# NOTE(review): the docstrings describe burn_in_ratio as a ratio (default
# 0.01), yet 4 is used here -- confirm the meaning expected by the estimator.
burn_in_ratio = 4
b = 0

import time
import os
import pickle

# ---- Parameter grid ----
r_ = [1, 0.9, 0.75, 0.5, 0.25]
a = 0.6
tau_ = [0.8, 0.5, 0.3]
c0 = 1
ct = 0                      # number of configurations completed so far
Ks = [20, 40, 80, 100]

# Run every (K, tau, r, distribution) combination and write one file each.
for K in Ks:
    for tau in tau_:
        for r in r_:
            for dist_type in dist_types:
                print(f'\r r:{r} tau:{tau} K:{K} :')

                started_at = time.time()
                output = run_simulation(
                    dist_type=dist_type,
                    tau=tau,
                    K=K,
                    r=r,
                    n_samples=n_samples,
                    n_simu=n_sim,
                    base_seed=seed,
                    alpha=0.05,
                    c0=c0,
                    a=a,
                    b=b,
                    burn_in_ratio=burn_in_ratio,
                )
                # Only the simulation is timed; saving is excluded.
                elapsed = time.time() - started_at
                ct += 1

                # Build the destination once and reuse it for save and log.
                out_path = f'./output_ga/{dist_type}/Kfix_{dist_type}_n_samples_{n_samples}_r_{r}_tau_{tau}_K_{K}_a_{a}_c_{c0}_burnin_ratio_new_{burn_in_ratio}_n_sim_{n_sim}.h5'
                save_hdf5(output, out_path)
                del output  # drop the result before the next run starts

                print(f'Saved in {out_path}')
                print(f"Elapsed time for this run: {elapsed:.2f} seconds")