import h5py
import json
import pickle
import numpy as np
import math
from scipy.stats import norm
from util import *
from DPQuantile import DPQuantile

def load_hdf5(file_path):
    """
    Load the root-level members and attributes of an HDF5 file into a dictionary.

    Every key returned by ``f.keys()`` is read fully into memory. Members whose
    dtype is a variable-length string are first tried as JSON (using their
    first element); if that fails the raw contents are stored instead. Root
    attributes are then copied into the same dictionary, so an attribute
    overwrites a dataset that has the same name.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.

    Returns
    -------
    dict
        Mapping from dataset/attribute name to its loaded value.
    """
    result = {}
    with h5py.File(file_path, 'r') as f:
        # Load all root-level datasets
        for key in f.keys():
            dataset = f[key]
            # If the dataset is of string type, it may be a JSON string
            if h5py.check_dtype(vlen=dataset.dtype) == str:
                try:
                    # Try to parse the first element as JSON
                    result[key] = json.loads(dataset[0])
                except (ValueError, TypeError, IndexError):
                    # Not JSON (json.JSONDecodeError is a ValueError), or the
                    # first element could not be read: keep the raw contents.
                    # Only these errors are caught so that KeyboardInterrupt
                    # and genuine bugs are no longer silently swallowed.
                    result[key] = dataset[:]
            else:
                # For regular datasets, load directly into memory
                result[key] = dataset[:]
        # Load all root-level attributes (these win on a name clash)
        for key in f.attrs.keys():
            result[key] = f.attrs[key]
    return result

def load_data_kchange(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=2000):
    """
    Load the "Kchange" result file that matches the given experiment settings.

    The file name is assembled from all arguments and looked up under
    ``./output_ga/<dist_type>/``. After loading, a one-line summary of every
    entry is printed (shape and dtype for NumPy arrays, the value itself for
    anything else).

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    radius_typ : str
        The type of radius calculation.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Hyperparameter controlling the ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000).

    Returns
    -------
    dict or None
        The loaded data dictionary, or None if anything went wrong while
        loading or summarising it (the error is printed, not raised).
    """
    filename =  f'./output_ga/{dist_type}/Kchange_radius_typ_{radius_typ}_n_samples_{n_samples}_r_{r}_tau_{tau}_a_{a}_c_{c0}_burnin_ratio_new_{burn_in_ratio}_n_sim_{n_sim}.h5'
    try:
        loaded = load_hdf5(filename)
        for name, item in loaded.items():
            # Arrays are summarised by shape/dtype; everything else is shown as-is.
            summary = (
                f"{name} shape: {item.shape}, type: {item.dtype}"
                if isinstance(item, np.ndarray)
                else f"{name}: {item}"
            )
            print(summary)
        return loaded
    except Exception as err:
        # Deliberate best-effort: report the problem and signal failure with None.
        print(f"Error loading file {filename}: {err}")
        return None
    
def _load_first_run(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim):
    """
    Load one Kchange result file and return row 0 of its arrays as float64.

    Returns
    -------
    tuple or None
        (estimates, variances, radius) taken from index 0 along the first axis
        of the stored arrays, or None if the file could not be loaded.
    """
    data = load_data_kchange(dist_type, radius_typ, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=n_sim)
    if data is None:
        print(f"Failed to load data, r={r}, tau={tau}, a={a}")
        return None
    run = 0  # only the first stored row is used (presumably the first simulation run)
    estimates = data['estimates'][run, :].astype(np.float64)
    variances = data['variances'][run, :].astype(np.float64)
    radius = data['radius'][run, :].astype(np.float64)
    return estimates, variances, radius


def _chain_count_per_step(n_samples, n_steps):
    """
    Return an integer array of length n_steps giving the K value used at steps 1..n_steps.

    A dictionary maps "change points" (step indices) to the K value that takes
    over after them; each step in the interval (previous change point, change
    point] is assigned that value minus one. Steps beyond the largest change
    point keep the value 0.
    """
    K_all = int(np.log10(n_samples) * 8)          # Number of chains
    Knum_cur = int(np.log10(n_samples / 5) * 8)   # Number of chains at the 1/5 position, where recording starts
    incr_k_with_t = {n_samples: K_all + 1}        # Record nodes where K changes

    for k in range(Knum_cur + 1, K_all + 1):
        start_key = math.ceil(10 ** (k / 8))
        remainder = start_key % (k - 1)
        # NOTE(review): the remainder is taken modulo (k - 1) but the adjustment
        # adds (k - remainder), so the resulting key is in general divisible by
        # neither k nor k - 1. The original comment said "divisible by i". The
        # arithmetic is kept exactly as it was because it must agree with the
        # code that produced the HDF5 files -- confirm the intended rule there.
        new_key = start_key if remainder == 0 else start_key + (k - remainder)
        incr_k_with_t[int(new_key)] = k

    step = np.arange(n_steps) + 1
    step_k = np.zeros_like(step)
    prev_key = 0
    for key in sorted(incr_k_with_t):
        # Interval: (prev_key, key]
        step_k[(step > prev_key) & (step <= key)] = incr_k_with_t[key] - 1
        prev_key = key
    return step_k


def get_luup_single(dist_type, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=2000):
    """
    Calculate and save lower and upper confidence sequence bounds for different methods (gm, ub, pt).

    The 'gm' result file supplies the estimates, variances and radius for the
    gm bounds and for the pt bounds (which use the normal quantile
    ``norm.ppf(1 - 0.05/2)`` instead of the stored radius). The 'ub' result
    file supplies the same three arrays for the ub bounds. Each bound is
    ``estimate -/+ sqrt(variance) / sqrt(K) * multiplier``. All six arrays are
    written as ``.npy`` files under ``./output_ga/<dist_type>/single_CS``.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r.
    tau : float
        Quantile parameter tau.
    a : float
        Learning rate exponent.
    c0 : float
        Initial learning rate coefficient.
    burn_in_ratio : float
        Hyperparameter controlling the ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000). Used to locate the input
        files; previously this argument was ignored and 2000 was always used.

    Returns
    -------
    tuple or None
        (lower_gm, upper_gm, lower_ub, upper_ub, lower_pt, upper_pt): arrays of
        lower and upper bounds for each method, or None if either input file
        could not be loaded (nothing is saved in that case).
    """
    import os

    gm = _load_first_run(dist_type, 'gm', n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
    if gm is None:
        return None
    estimates, variances, radius = gm
    del gm

    # step_k broadcasts against the last axis of `variances`, so that axis
    # defines the number of steps. (shape[1] raised IndexError for 1-D rows.)
    step_k = _chain_count_per_step(n_samples, variances.shape[-1])

    tmp_length = 2 * np.sqrt(variances) / np.sqrt(step_k)
    length_gm = tmp_length * radius
    length_pt = tmp_length * norm.ppf(1 - 0.05 / 2)
    del tmp_length, radius, variances  # Free memory

    lower_gm = estimates - length_gm / 2
    upper_gm = estimates + length_gm / 2
    lower_pt = estimates - length_pt / 2
    upper_pt = estimates + length_pt / 2
    print(estimates[-1])
    del estimates, length_gm, length_pt  # Free memory

    # The ub bounds must come from the 'ub' file. The previous code read from
    # the already-deleted gm data here and therefore always crashed.
    ub = _load_first_run(dist_type, 'ub', n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
    if ub is None:
        return None
    estimates, variances, radius = ub
    del ub

    length_ub = 2 * np.sqrt(variances) / np.sqrt(step_k) * radius
    del radius, variances  # Free memory

    lower_ub = estimates - length_ub / 2
    upper_ub = estimates + length_ub / 2
    del estimates, length_ub  # Free memory

    # Save files
    save_dir = f"./output_ga/{dist_type}/single_CS"
    os.makedirs(save_dir, exist_ok=True)
    bounds = {
        "lower_gm": lower_gm,
        "lower_ub": lower_ub,
        "lower_pt": lower_pt,
        "upper_gm": upper_gm,
        "upper_ub": upper_ub,
        "upper_pt": upper_pt,
    }
    for name, values in bounds.items():
        np.save(f"{save_dir}/{name}_{r}_{tau}_{a}.npy", values)
    print(f'Files saved successfully')
    return lower_gm, upper_gm, lower_ub, upper_ub, lower_pt, upper_pt

def get_luup_single_SN(dist_type, n_samples, r, tau, a, c0, burn_in_ratio, n_sim=2000):
    """
    Calculate and save lower and upper confidence sequence bounds for the SN method.

    A fresh data stream is generated with ``generate_data``, a ``DPQuantile``
    estimator is fitted to its first element, and the bounds are
    ``Q_avg -/+ 6.74735 * sqrt(var)``. Both arrays are written as ``.npy``
    files under ``./output_ga/<dist_type>/single_CS``.

    Parameters
    ----------
    dist_type : str
        The type of distribution.
    n_samples : int
        Number of samples.
    r : float
        Response rate parameter r. Only used to scale the estimator's burn-in
        ratio and in the output file names.
    tau : float
        Quantile parameter tau.
    a : float
        Learning rate exponent. Only used in the output file names.
    c0 : float
        Initial learning rate coefficient. Not used by this function.
    burn_in_ratio : float
        Hyperparameter controlling the ratio of burn-in samples.
    n_sim : int, optional
        Number of simulations (default is 2000). Not used by this function.

    Returns
    -------
    tuple
        (lower_sn, upper_sn): arrays of lower and upper bounds for the SN method.
    """
    import os

    # tau now follows the argument (it was hard-coded to 0.8 although the data
    # and the file names use `tau`).
    # NOTE(review): r=1 is still hard-coded while `r` only enters through the
    # burn-in ratio -- confirm this is intended for the SN baseline.
    SN = DPQuantile(tau=tau, r=1, true_q=None, burn_in_ratio=(r**2 / (100 * burn_in_ratio)))
    data_stream = generate_data(dist_type, tau, n_samples)
    SN.fit(np.array(data_stream[0]))

    # 6.74735 is presumably the critical value of the self-normalized
    # statistic at the 95% level -- TODO confirm the source of this constant.
    half_width = np.sqrt(SN.var) * 6.74735
    upper_sn = SN.Q_avg + half_width
    lower_sn = SN.Q_avg - half_width

    # The f-prefix was missing, so files were written to a directory literally
    # named "{dist_type}" instead of the per-distribution folder.
    save_dir = f"./output_ga/{dist_type}/single_CS"
    os.makedirs(save_dir, exist_ok=True)

    filename_lower_SN = f"{save_dir}/lower_sn_{r}_{tau}_{a}.npy"
    filename_upper_SN = f"{save_dir}/upper_sn_{r}_{tau}_{a}.npy"
    np.save(filename_lower_SN, lower_sn)
    np.save(filename_upper_SN, upper_sn)
    print(f'Files saved successfully')
    return lower_sn, upper_sn

# Driver: for every (tau, r, distribution) combination, build and save the
# gm/ub/pt bounds and then the SN bounds, reporting the wall-clock time of each run.
# `time` is imported explicitly here; it was previously only available if the
# star import from `util` happened to export it.
import time

n_samples = 5000000
n_sim = 2000
dist_types = ['normal','cauchy','laplace']
burn_in_ratio = 4
b = 0                       # not referenced in the visible code
radius_typs = ['ub','gm']   # not referenced in the visible code
r_ = [1,0.9,0.75]  #,0.5,0.25
a = 0.6
tau_ = [0.8] #,0.5,0.3
c0 = 1
ct = 0                      # not referenced in the visible code
for tau in tau_:
    for r in r_:
        for dist_type in dist_types:
            print(f'\r dist_type:{dist_type} r:{r} tau:{tau} :')
            t1 = time.time()
            get_luup_single(dist_type, n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
            get_luup_single_SN(dist_type, n_samples, r, tau, a, c0, burn_in_ratio, n_sim)
            t2 = time.time()
            print(f"Elapsed time for this run: {t2 - t1:.2f} seconds")