import itertools
import json
import math
import pickle
import time

import h5py
import numpy as np

def load_hdf5(file_path):
    """
    Load the top-level datasets and root attributes of an HDF5 file.

    Each top-level entry is read fully into memory. Entries whose dtype is a
    variable-length ``str`` are treated as possibly JSON-encoded: the first
    element is decoded with ``json.loads`` and, if that fails, the raw
    contents are stored instead. Root attributes are copied into the same
    dictionary afterwards, so an attribute overwrites a dataset that has the
    same name.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.

    Returns
    -------
    dict
        Dictionary containing all datasets and attributes from the HDF5 file.
    """
    result = {}
    with h5py.File(file_path, 'r') as f:
        # Load all datasets
        for key in f:
            dataset = f[key]
            # NOTE(review): assumes every top-level entry is a dataset (it
            # must expose ``.dtype``) -- confirm the files contain no groups.
            if h5py.check_dtype(vlen=dataset.dtype) == str:
                try:
                    # A JSON payload is expected in the first element.
                    result[key] = json.loads(dataset[0])
                except (ValueError, TypeError, IndexError):
                    # Not JSON (json.JSONDecodeError is a ValueError), not a
                    # str/bytes element, or no first element: keep raw data.
                    # Only these errors are caught so that KeyboardInterrupt
                    # and SystemExit are no longer swallowed here.
                    result[key] = dataset[:]
            else:
                # For regular datasets, load directly into memory
                result[key] = dataset[:]
        # Load all attributes
        for key in f.attrs:
            result[key] = f.attrs[key]
    return result

def load_data_kchange(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=2000):
    """
    Read one "Kchange" result file and print a summary of its contents.

    The file path is derived from the arguments and lives under
    ``./output_ga/<dist_type>/``. Every loaded entry is echoed to stdout:
    shape and dtype for NumPy arrays, the value itself for anything else.

    Parameters
    ----------
    dist_type : str
        The type of distribution (also the sub-directory name).
    radius_typ : str
        The type of radius calculation.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Hyperparameter controlling the ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000).

    Returns
    -------
    dict or None
        The loaded data dictionary, or None if loading or summarising raised
        an exception (the error is printed, not re-raised).
    """
    h5_path = (
        f'./output_ga/{dist_type}/Kchange_radius_typ_{radius_typ}'
        f'_n_samples_{n_samples}_r_{r}_tau_{tau}_a_{a}_c_{c0}'
        f'_burnin_ratio_new_{burn_in_ratio}_n_sim_{n_sim}.h5'
    )
    try:
        loaded = load_hdf5(h5_path)
        for name in loaded:
            entry = loaded[name]
            if not isinstance(entry, np.ndarray):
                print(f"{name}: {entry}")
                continue
            print(f"{name} shape: {entry.shape}, type: {entry.dtype}")
    except Exception as err:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error loading file {h5_path}: {err}")
        return None
    return loaded

  
def _build_k_schedule(n_samples):
    """Return the ``{step threshold: chain index}`` map at which K changes."""
    k_all = int(np.log10(n_samples) * 8)            # Number of chains
    k_start = int(np.log10(n_samples / 5) * 8)      # Chains at the 1/5 position
    schedule = {n_samples: k_all + 1}
    for i in range(k_start + 1, k_all + 1):
        start_key = math.ceil(10 ** (i / 8))
        # NOTE(review): the remainder is taken modulo (i - 1) but the padding
        # below uses i, so the rounded key is not a multiple of either value
        # in general. The arithmetic is kept exactly as it was because it has
        # to match the schedule used when the data was generated -- confirm
        # against the simulation code before changing it.
        remainder = start_key % (i - 1)
        new_key = start_key if remainder == 0 else start_key + (i - remainder)
        schedule[int(new_key)] = i
    return schedule


def _chains_per_step(n_steps, schedule):
    """Return, for steps 1..n_steps, the chain count given by ``schedule``."""
    step = np.arange(n_steps) + 1
    step_k = np.zeros_like(step)
    prev_key = 0
    for key in sorted(schedule):
        # Interval (prev_key, key] uses one chain fewer than the stored index.
        step_k[(step > prev_key) & (step <= key)] = schedule[key] - 1
        prev_key = key
    # Steps beyond the largest threshold keep the value 0.
    return step_k


def get_typeIerror_meanlength(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=2000):
    """
    Calculate type I error and mean confidence interval length for the given parameters.

    The result file is loaded with ``load_data_kchange`` and its
    ``'estimates'``, ``'variances'``, ``'radius'`` and ``'true_q'`` entries
    are used. The interval length is ``2 * sqrt(variances) / sqrt(K) * radius``,
    where ``K`` is the per-step chain count derived from ``n_samples``. An
    interval covers when ``true_q`` lies within ``estimates +/- length / 2``;
    coverage is accumulated along axis 1 (once missed, a row counts as missed
    for all later steps) and the error is one minus its mean over axis 0.
    Both result arrays are written as ``.npy`` files under
    ``./output_ga/<dist_type>/typeI_error_length_<radius_typ>/``.

    The arrays are presumably shaped ``(n_sim, n_steps)`` -- TODO confirm
    against the code that writes the HDF5 files.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    radius_typ : str
        The type of radius calculation.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Hyperparameter controlling the ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000).

    Returns
    -------
    tuple or None
        ``(errors, mean_length)``: errors is the type I error array,
        mean_length is the mean interval length array. None is returned when
        the data file could not be loaded.
    """
    import os  # kept local: this file does not import os at module level

    data = load_data_kchange(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
    if data is None:
        print(f"Failed to load data, dist_type = {dist_type}, radius_typ = {radius_typ}, n_samples={n_samples}, r={r}, tau={tau}, a={a}, n_sim={n_sim}")
        return None

    # Extract data
    estimates = data['estimates'].astype(np.float32)
    variances = data['variances'].astype(np.float32)
    radius = data['radius'].astype(np.float32)
    true_q = data['true_q']
    del data  # Free memory

    step_k = _chains_per_step(variances.shape[1], _build_k_schedule(n_samples))

    length = 2 * np.sqrt(variances) / np.sqrt(step_k) * radius
    del radius, variances  # Free memory
    mean_length = np.mean(length, axis=0)

    half_length = length / 2
    del length  # Free memory
    covered = (true_q >= estimates - half_length) & (true_q <= estimates + half_length)
    del estimates, half_length  # Free memory

    # Running AND along the step axis, done in one vectorised call instead of
    # a Python loop over rows.
    cumulative_covered = np.logical_and.accumulate(covered, axis=1)
    del covered  # Free memory
    errors = 1 - cumulative_covered.mean(axis=0)
    del cumulative_covered
    print(f'Errors: {errors}')

    # Save files
    save_dir = f"./output_ga/{dist_type}/typeI_error_length_{radius_typ}"
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(save_dir, exist_ok=True)
    filename_mean = f"{save_dir}/mean_length_{r}_{tau}_{a}.npy"
    filename_errors = f"{save_dir}/errors_{r}_{tau}_{a}.npy"
    np.save(filename_mean, mean_length)
    np.save(filename_errors, errors)
    print(f'Saved file: {filename_mean}')
    print(f'Saved file: {filename_errors}')
    return errors, mean_length

# Parameter sweep: compute and save type I error / mean interval length for
# every combination of quantile level, response rate, distribution and radius
# type. Relies on the module-level ``time`` and ``itertools`` imports.
n_samples = 5000000
n_sim = 2000
dist_types = ['normal', 'cauchy', 'laplace']
burn_in_ratio = 4
b = 0   # not referenced in the visible code; kept as a module-level name
radius_typs = ['ub', 'gm']
r_ = [1, 0.9, 0.75, 0.5, 0.25]
a = 0.6
tau_ = [0.8, 0.5, 0.3]
c0 = 1
ct = 0  # not referenced in the visible code; kept as a module-level name
# product() varies the right-most list fastest, which is the same order as
# nesting the loops tau -> r -> dist_type -> radius_typ.
for tau, r, dist_type, radius_typ in itertools.product(tau_, r_, dist_types, radius_typs):
    print(f'\r dist_type:{dist_type} radius_typ:{radius_typ} r:{r} tau:{tau} :')
    t1 = time.time()
    get_typeIerror_meanlength(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
    t2 = time.time()
    print(f"Elapsed time for this run: {t2 - t1:.2f} seconds")