import numpy as np
import sys
from util import *

def generate_federate_data(dist_type, tau, n_n_samples, biases):
    """
    Draw one sample per site and collect each site's true τ-quantile.

    Every site is simulated by a call to ``generate_data`` with the same
    distribution family and quantile level, but its own sample size and
    location shift.

    Parameters
    ----------
    dist_type   : str
        Distribution family, forwarded unchanged to ``generate_data``
        (e.g. 'normal', 'uniform', 'cauchy').
    tau         : float
        Target quantile level, forwarded unchanged to ``generate_data``.
    n_n_samples : list[int]
        Sample size for each site; its length defines the number of sites.
    biases      : list[float]
        Location shift for each site; ``biases[i]`` is passed as ``mu`` for
        site ``i``.  It therefore needs one entry per site, *including* the
        target at index 0 -- this function does not force that entry to 0,
        so the caller must supply 0 there if the target is to be unshifted.

    Returns
    -------
    datas   : list
        The first element returned by ``generate_data`` for each site
        (the generated sample).
    true_qs : list
        The second element returned by ``generate_data`` for each site
        (the population τ-quantile).
    """
    datas, true_qs = [], []
    for site, n_obs in enumerate(n_n_samples):
        sample, quantile = generate_data(dist_type, tau, n_obs,
                                         mu=biases[site])
        datas.append(sample)
        true_qs.append(quantile)
    return datas, true_qs


def get_prop_K(datas, K_base=1):
    """
    Allocate a number of chains to each site in proportion to its sample
    size, using the first site as the reference.

    Site ``i`` receives ``K_base * len(datas[i]) / len(datas[0])`` chains.
    The ratio is converted with ``astype(int)``, i.e. it is truncated
    toward zero rather than rounded to the nearest integer, so a site
    smaller than the reference can end up with 0 chains when ``K_base`` is 1.

    Parameters
    ----------
    datas  : sequence of sized objects
        One sample per site; only the lengths are used.
    K_base : number, default=1
        Number of chains given to the reference site (index 0).

    Returns
    -------
    np.ndarray of int
        Chain count per site.
    """
    sizes = np.array(list(map(len, datas)))
    scaled = K_base * sizes / sizes[0]
    return scaled.astype(int)

def get_n_n_sample(n_sites, n_sample_target, source_prop=1):
    """
    Build the list of per-site sample sizes.

    The first entry is the target site's size; each of the remaining
    ``n_sites - 1`` source sites gets ``int(n_sample_target * source_prop)``
    observations (the product is truncated by ``int``):

        [ n_target , int(source_prop * n_target) , … ]

    Parameters
    ----------
    n_sites         : int
        Total number of sites, target included.
    n_sample_target : int
        Sample size of the target site.
    source_prop     : number, default=1
        Source-site sample size as a multiple of the target's.

    Returns
    -------
    list
        Sample sizes, target first.
    """
    per_source = int(n_sample_target * source_prop)
    sizes = [n_sample_target]
    sizes.extend([per_source] * (n_sites - 1))
    return sizes

def get_true_var(dist_type='normal', mu=0, tau=0.5, r=1, true_q=0):
    """
    Population variance of the LDP-quantile estimator:

        Var = [ 1 − r²(2τ − 1)² ] / [ 4 r² f_X(Q)² ].

    Parameters
    ----------
    dist_type : {'normal', 'uniform', 'cauchy'}
        Distribution family of X: N(μ, 1), U(μ − 1, μ + 1) or Cauchy(μ, 1).
    mu        : float
        Location parameter μ of the distribution.
    tau       : float
        Quantile level τ.
    r         : float
        The factor ``r`` appearing in the formula above.
    true_q    : float
        Point Q at which the density f_X is evaluated.

    Returns
    -------
    float
        The variance.  When the denominator is exactly zero (the density
        vanishes at ``true_q``, e.g. outside the uniform support, or
        ``r == 0``) the result is ``inf``.

    Raises
    ------
    ValueError
        If ``dist_type`` is not one of the supported families.
    """
    # Local import: this module does not import `cauchy` explicitly, so
    # without it the function only works if a wildcard import supplies it.
    from scipy.stats import cauchy, norm

    # 1) Density f_X(Q)
    if dist_type == 'normal':
        fx = norm.pdf(true_q, loc=mu, scale=1)          # N(μ,1)
    elif dist_type == 'uniform':
        low, high = mu - 1, mu + 1                      # U(μ−1, μ+1)
        fx = 1.0 / (high - low) if low <= true_q <= high else 0.0
    elif dist_type == 'cauchy':
        fx = cauchy.pdf(true_q, loc=mu, scale=1)        # Cauchy(μ,1)
    else:
        raise ValueError("Unsupported distribution type")

    # 2) Plug into the formula
    numerator = 1.0 - r**2 * (2 * tau - 1)**2           # 1 - r²(2τ-1)²
    denominator = 4.0 * r**2 * fx**2

    # A zero denominator used to raise ZeroDivisionError for Python floats
    # but yield inf (with a warning) for NumPy scalars; return inf in both
    # cases so the result does not depend on the scalar type passed in.
    if np.ndim(denominator) == 0 and denominator == 0:
        return float('inf')
    return numerator / denominator

def compute_oracle(estimates, true_qs, mus,
                   rs, tau=0.5,
                   lambd=0.1, dist_type='normal'):
    """
    Compute oracle weights and the oracle-aggregated estimate.

    "Oracle" means population quantities are used throughout: each site's
    variance comes from ``get_true_var`` and each site's bias is the
    difference between its true quantile and that of the target site
    (index 0).  The weights are then obtained from
    ``compute_optimal_weights(biases, true_vars, lambd=lambd)``.

    Parameters
    ----------
    estimates : array_like
        Per-site estimates.  They are combined with ``np.dot(estimates,
        weights)``, so the last axis must match the weight vector
        (presumably length K, one entry per site -- confirm against
        ``compute_optimal_weights``).
    true_qs   : array_like, shape (K,)
        Population quantile of each site; index 0 is the target.
    mus       : iterable
        Location parameter of each site, passed to ``get_true_var``.
    rs        : array_like
        Per-site ``r`` value, passed to ``get_true_var``.
    tau       : float, default=0.5
        Quantile level, passed to ``get_true_var``.
    lambd     : float, default=0.1
        Forwarded to ``compute_optimal_weights``.
    dist_type : str, default='normal'
        Distribution family, passed to ``get_true_var``.

    Returns
    -------
    weights    : as returned by ``compute_optimal_weights``
        Oracle weights.
    oracle_est : ndarray or scalar
        ``np.dot(estimates, weights)``.
    """
    # Prepare vectors
    estimates, true_qs, rs = (np.asarray(a, float)
                              for a in (estimates, true_qs, rs))
    true_vars = np.array([get_true_var(dist_type, mu, tau, r, q)
                          for mu, r, q in zip(mus, rs, true_qs)])
    # Biases relative to the target site (b₀ = 0)
    biases = true_qs - true_qs[0]
    weights = compute_optimal_weights(biases, true_vars, lambd=lambd)
    # Oracle estimate (same weights for every simulation)
    oracle_est = np.dot(estimates, weights)

    return weights, oracle_est

def analyze_results(results, est_key='opt'):
    """
    Summarise an estimator by RMSE, empirical coverage of its confidence
    interval, and mean interval length.

    The interval is ``est ± z * sqrt(var)`` and everything is measured
    against the target site's true quantile, ``results['true_qs'][0]``.
    The critical value ``z`` is 6.74735 when ``est_key == 'target'`` and
    1.96 for any other key.

    Parameters
    ----------
    results : dict
        Must contain keys '<est_key>_est', '<est_key>_var', and 'true_qs'.
    est_key : {'target', 'opt'}, default='opt'
        Which estimator to evaluate.

    Returns
    -------
    dict with keys 'rmse', 'coverage', 'CIlen'.
    """
    critical = 6.74735 if est_key == 'target' else 1.96

    point = results[est_key + '_est']
    variance = results[est_key + '_var']
    truth = results['true_qs'][0]

    half_width = critical * np.sqrt(variance)
    lower, upper = point - half_width, point + half_width
    inside = (true_lower_ok := (truth >= lower)) & (truth <= upper)

    summary = {}
    summary['rmse'] = np.sqrt(np.mean((point - truth)**2))
    summary['coverage'] = np.mean(inside)
    summary['CIlen'] = np.mean(upper - lower)
    return summary



# --------------------------------------------------------------------------- #
# Classical (non-private) plug-in variance                                 #
# --------------------------------------------------------------------------- #
import numpy as np
from scipy.stats import gaussian_kde, norm

# Hall–Sheather n^{-1/3} bandwidth
def hall_sheather_bw(data: np.ndarray, p):
    """
    Hall–Sheather-style n^{-1/3} bandwidth for estimating the density at
    quantile level ``p``.

    Computes

        h = (1.5 · φ(z)²)^{1/3} · n^{-1/3} · s,

    where ``z = Φ⁻¹(p)``, φ is the standard normal density, ``n`` is the
    number of observations and ``s`` is the sample standard deviation
    (``ddof=1``), so the result is on the scale of the data.

    NOTE(review): the rule as usually quoted from Hall & Sheather (1988)
    also carries a ``1 / (2 z² + 1)`` term inside the cube root and a
    ``z_α^{2/3}`` factor; neither appears here -- confirm this simplified
    form is intentional.

    Parameters
    ----------
    data : np.ndarray
        Sample; must support ``len`` and ``.std(ddof=1)``.
    p    : float
        Quantile level.

    Returns
    -------
    float
        The bandwidth ``h``.
    """
    n_obs = len(data)
    normal_density = norm.pdf(norm.ppf(p))
    constant = (1.5 * normal_density ** 2) ** (1 / 3)
    spread = data.std(ddof=1)
    return constant * n_obs ** (-1 / 3) * spread

# Order statistic quantile + plug-in variance
def quantile_plugin(data, p, bandwidth='hall-sheather'):
    """
    Empirical ``p``-quantile together with its plug-in variance estimate

        p(1−p) / [ n · f̂(q̂)² ].

    ``q̂`` is the order statistic with 0-based rank ``ceil(n·p) − 1`` (no
    interpolation) and ``f̂`` is a Gaussian kernel density estimate of the
    data evaluated at ``q̂``.

    Parameters
    ----------
    data      : array_like
        Sample; converted to a float array.
    p         : float
        Quantile level, strictly between 0 and 1.
    bandwidth : default='hall-sheather'
        With 'hall-sheather', the bandwidth from ``hall_sheather_bw`` is
        divided by the sample standard deviation (``ddof=1``) and passed to
        ``gaussian_kde`` as ``bw_method``.  Any other value is handed to
        ``gaussian_kde`` as ``bw_method`` unchanged.

    Returns
    -------
    qhat    : float
        The empirical quantile.
    var_hat : float
        The plug-in variance estimate.

    Raises
    ------
    ValueError
        If ``p`` is not in the open interval (0, 1).
    """
    sample = np.asarray(data, dtype=float)
    n_obs = len(sample)
    if p <= 0 or p >= 1:
        raise ValueError("p must be in (0,1).")

    # Pure order statistic: partial sort places the rank-th smallest value
    # at position `rank`.
    rank = int(np.ceil(n_obs * p)) - 1
    qhat = np.partition(sample, rank)[rank]

    # gaussian_kde takes a scalar bw_method as a factor relative to the
    # data's standard deviation, hence the division for the absolute
    # Hall–Sheather bandwidth.
    if bandwidth == 'hall-sheather':
        bw_method = hall_sheather_bw(sample, p) / sample.std(ddof=1)
    else:
        bw_method = bandwidth
    density = gaussian_kde(sample, bw_method=bw_method)
    fhat = density.evaluate([qhat])[0]

    var_hat = p * (1 - p) / (n_obs * fhat ** 2)
    return qhat, var_hat