import json
import math
import pickle
import time

import h5py
import numpy as np

def load_hdf5(file_path):
    """
    Load all top-level datasets and root attributes from an HDF5 file.

    Variable-length ``str`` datasets are treated as possibly holding JSON:
    their first element (or the single value of a scalar dataset) is passed
    to ``json.loads``. If parsing succeeds, the parsed object is stored.
    Otherwise the raw dataset contents are stored instead.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.

    Returns
    -------
    dict
        Maps each top-level dataset name and each root attribute name to its
        value. A root attribute with the same name as a dataset overwrites
        the dataset entry. Top-level groups are not handled (they have no
        ``shape``/``dtype``).
    """
    result = {}
    with h5py.File(file_path, 'r') as f:
        # Load all top-level datasets
        for key in f.keys():
            dataset = f[key]
            # 0-d (scalar) datasets reject both `[0]` and `[:]`; `[()]` reads
            # their single value.
            is_scalar = dataset.shape == ()
            if h5py.check_vlen_dtype(dataset.dtype) is str:
                try:
                    first = dataset[()] if is_scalar else dataset[0]
                    result[key] = json.loads(first)
                    continue
                except (IndexError, TypeError, ValueError):
                    # Not JSON (JSONDecodeError and UnicodeDecodeError are
                    # ValueError subclasses) or an empty dataset: fall
                    # through and store the raw contents.
                    pass
            result[key] = dataset[()] if is_scalar else dataset[:]
        # Load root-level attributes
        for key in f.attrs.keys():
            result[key] = f.attrs[key]
    return result

def load_data_fix(dist_type, n_samples, r, tau, K, a, c0, burn_in_ratio, n_sim=2000):
    """
    Read the fixed-K result file for one parameter combination.

    The file name is built from every argument under
    ``./output_ga/<dist_type>/``. Each loaded entry is printed: arrays as
    shape/dtype, everything else by value.

    Parameters
    ----------
    dist_type : str
        The type of distribution (also the sub-directory name).
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    K : int
        Number of chains.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Burn-in ratio hyperparameter.
    n_sim : int, optional
        Number of simulations (default is 2000).

    Returns
    -------
    dict or None
        The loaded data dictionary. If any error occurs while loading or
        printing, the error is reported and None is returned.
    """
    filename = f'./output_ga/{dist_type}/Kfix_{dist_type}_n_samples_{n_samples}_r_{r}_tau_{tau}_K_{K}_a_{a}_c_{c0}_burnin_ratio_new_{burn_in_ratio}_n_sim_{n_sim}.h5'
    try:
        loaded = load_hdf5(filename)
        for name, item in loaded.items():
            summary = (
                f"{name} shape: {item.shape}, type: {item.dtype}"
                if isinstance(item, np.ndarray)
                else f"{name}: {item}"
            )
            print(summary)
    except Exception as err:
        # Best-effort loader: any failure is reported and signalled by None.
        print(f"Error loading file {filename}: {err}")
        return None
    return loaded

def get_var_fix(dist_type, n_samples, r, tau, K, a, c0, burn_in_ratio, n_sim=2000):
    """
    Calculate and save the variance for the fixed-K setting.

    Loads the result file via ``load_data_fix``. One column (``place``) of
    ``data['variances']`` is taken and rescaled by ``step_new / step_old``
    at that column. The result is saved as a ``.npy`` file under
    ``./output_ga/<dist_type>/var``.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    K : int
        Number of chains.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000).

    Returns
    -------
    np.ndarray or None
        The rescaled column ``var[:, place]`` if successful, or None if the
        data file could not be loaded.

    Raises
    ------
    IndexError
        If the computed column index ``place`` is outside
        ``data['variances']``.
    """
    import os

    data = load_data_fix(dist_type, n_samples, r, tau, K, a, c0, burn_in_ratio, n_sim)
    if data is None:
        print(f"Failed to load data, r={r}, tau={tau}, a={a}")
        return None
    # Extract the variance trajectories and release the rest of the file.
    var = data['variances']
    del data
    n_steps = var.shape[1]
    print(f'var:{n_steps}')

    # NOTE(review): 4799951 is an unexplained hard-coded constant and the
    # column index scales linearly with K. Confirm its meaning against the
    # code that writes the HDF5 file.
    place = int(4799951 / 100) * K - 1
    if not 0 <= place < n_steps:
        raise IndexError(
            f"column index {place} (K={K}) is out of range for 'variances' "
            f"with {n_steps} columns"
        )

    # Element `place` of np.arange(n_steps) + n_samples + 1 - n_steps
    # (old step count) and of np.arange(n_steps) + 1 (new step count),
    # computed directly instead of materialising both arrays.
    step_old = place + n_samples + 1 - n_steps
    step_new = place + 1
    variances = var[:, place] / step_old * step_new

    # Save file. NOTE(review): the name omits n_samples, c0, burn_in_ratio and
    # n_sim, so runs differing only in those overwrite each other.
    save_dir = f"./output_ga/{dist_type}/var"
    os.makedirs(save_dir, exist_ok=True)
    filename_var = f"{save_dir}/var_{r}_{tau}_{a}_{K}.npy"
    np.save(filename_var, variances)
    print(f'Saved file: {filename_var}')
    return variances

# Experiment grid: run get_var_fix once per (K, tau, r, dist_type) combination.
n_samples = 5000000
n_sim = 500
dist_types = ['normal', 'cauchy', 'laplace']
burn_in_ratio = 4
b = 0  # NOTE(review): not used by the visible code in this file
r_ = [1, 0.9, 0.75, 0.5, 0.25]
a = 0.6
tau_ = [0.8, 0.5, 0.3]
c0 = 1
ct = 0  # NOTE(review): not used by the visible code in this file
Ks = [20, 40, 80, 100]
for K in Ks:
    for tau in tau_:
        for r in r_:
            for dist_type in dist_types:
                print(f'\r r:{r} tau:{tau} K:{K} dist_type:{dist_type}:')
                # perf_counter is monotonic and intended for measuring durations.
                t1 = time.perf_counter()
                variances = get_var_fix(dist_type, n_samples, r, tau, K, a, c0, burn_in_ratio, n_sim)
                t2 = time.perf_counter()
                print(f"Elapsed time for this run: {t2 - t1:.2f} seconds")