import mpmath as mp
import numpy as np
import logging

from OAEGN_analysis.radial_rv import generalized_gamma_moments

def mp_f_U(u, T):
    """
    Evaluate the density of U = cos(theta) for a random direction in R^T, in mpmath arithmetic.

    f_U(u; T) = Gamma(T/2) / (sqrt(pi) * Gamma((T-1)/2)) * (1 - u^2)^((T-3)/2)
    """
    exponent = (T - 3) / 2
    normaliser = mp.gamma(T / 2) / (mp.sqrt(mp.pi) * mp.gamma((T - 1) / 2))
    return normaliser * (1 - u**2) ** exponent

def mp_Phi(x):
    """
    Standard normal CDF Phi(x) = (1 + erf(x / sqrt(2))) / 2, evaluated with mpmath.
    """
    half = mp.mpf('0.5')
    return half * (mp.erf(x / mp.sqrt(2)) + 1)

def mp_exact_integrand(u, c, T):
    """
    Integrand (Phi(c/(2u)) - 1/2) * f_U(u; T) of the exact statistical-distance integral.

    c/(2u) is undefined at u = 0, so that point returns the one-sided limit
    u -> 0+: Phi(c/(2u)) tends to 1 for c > 0, 1/2 for c == 0 and 0 for c < 0.
    The limit is therefore sign(c)/2 * f_U(0; T).
    """
    if u == 0:
        # One-sided limit instead of 0, so the integrand is continuous at u = 0.
        return mp.sign(c) * mp.mpf('0.5') * mp_f_U(mp.mpf(0), T)
    return (mp_Phi(c / (2 * u)) - mp.mpf('0.5')) * mp_f_U(u, T)

def mp_compute_exact_sd_integral(c, T, prec=50, **quad_kwargs):
    """
    Evaluate I(c, T) = int_0^1 (Phi(c/(2u)) - 1/2) f_U(u; T) du with mpmath.

    Side effect: sets the global mpmath working precision ``mp.mp.dps`` to
    ``prec``. Extra keyword arguments are forwarded unchanged to ``mp.quad``.
    """
    mp.mp.dps = prec

    def integrand(u):
        return mp_exact_integrand(u, c, T)

    return mp.quad(integrand, [0, 1], **quad_kwargs)


def mp_phi_T(u, T):
    """
    Density of N(0, 1/(T-1)) at u, evaluated with mpmath.

    This is the full two-sided normal density; no restriction to u > 0 is
    applied here. The integration range chosen by the caller handles that.
    """
    normaliser = mp.sqrt(T - 1) / mp.sqrt(2 * mp.pi)
    return normaliser * mp.exp(-0.5 * (T - 1) * u ** 2)

def mp_approximate_integrand(u, c, T):
    """
    Integrand Phi(c / (2u)) * phi_T(u; T) of the approximate statistical-distance integral.

    Undefined at u = 0 (division by zero). The integration range used by
    mp_compute_approximate_sd_integral starts strictly above zero.
    """
    cdf_part = mp_Phi(c / (2 * u))
    return cdf_part * mp_phi_T(u, T)

def mp_compute_approximate_sd_integral(c, T, prec=50, **quad_kwargs):
    """
    Compute the integral int_{1e-10}^1 Phi(c/(2u)) * phi_T(u, T) du with high precision.

    The lower limit is 1e-10 rather than 0 because the integrand divides by u.
    Side effect: sets the global mpmath precision ``mp.mp.dps`` to ``prec``.
    Extra keyword arguments are forwarded to ``mp.quad``, mirroring
    ``mp_compute_exact_sd_integral``.
    """
    mp.mp.dps = prec
    # Start just above u = 0 to avoid the c/(2u) singularity in the integrand.
    lower = mp.mpf('1e-10')
    upper = mp.mpf('1')
    return mp.quad(lambda u: mp_approximate_integrand(u, c, T), [lower, upper], **quad_kwargs)


def compute_standard_elliptical_gaussian_statistical_distance(sigma, dim, mu_0, mu_1, prec=50, method='exact'):
    """
    Statistical distance between two elliptical Gaussian distributions that share
    the covariance sigma^2 * I but have means mu_0 and mu_1.

    Parameters:
    - sigma: positive noise scale
    - dim: dimension; mu_0 and mu_1 must be 1-D arrays of this length
    - prec: mpmath decimal precision used by the integral
    - method: 'exact' (integral against the density of U = cos(theta)) or
      'approximate' (U replaced by N(0, 1/(dim-1)))

    Returns max(0, estimate). Raises ValueError for an unknown method.
    """
    assert sigma > 0, "The standard deviation sigma must be positive"
    assert dim > 0, "The dimension dim must be positive"
    assert mu_0.shape == mu_1.shape, "The means must have the same shape"
    assert mu_0.ndim == 1 and mu_1.ndim == 1, "The means must be vectors"
    assert mu_0.shape[0] == dim and mu_1.shape[0] == dim, "The means must have the same dimension as the parameter dim"

    # Separation of the means measured in units of the noise scale.
    c = np.linalg.norm(mu_0 - mu_1, ord=2) / sigma

    if method == 'exact':
        estimate = 4 * mp_compute_exact_sd_integral(c=c, T=dim, prec=prec)
    elif method == 'approximate':
        integral = mp_compute_approximate_sd_integral(c=c, T=dim, prec=prec)
        # erf(sqrt((dim-1)/2)) = 4 * int_0^1 (1/2) phi_T(u) du, i.e. this turns
        # 4 * int Phi * phi_T into (approximately) 4 * int (Phi - 1/2) * phi_T.
        estimate = 4 * integral - mp.erf(mp.sqrt((dim - 1) / 2))
    else:
        raise ValueError(f"Invalid method: {method}")
    return max(0, estimate)

def compute_elliptical_gaussian_noise(epsilon, delta, dim, mu_0, mu_1, prec=50, method='exact', sigma_min=1e-10, sigma_max=1e10, tolerance=1e-8):
    """
    Compute the noise standard deviation sigma required for (epsilon, delta)-differential privacy.

    Binary-searches sigma in [sigma_min, sigma_max] for the smallest value at which
    the statistical distance between the two noised distributions is <= delta.
    This assumes the distance is non-increasing in sigma.

    Returns:
    0.0 when mu_0 == mu_1 (no noise needed). Otherwise, the upper end of the
    final search bracket: a sigma whose distance was <= delta, or the initial
    sigma_max if no such sigma was found.

    Raises:
    ValueError for an unknown ``method``. Without the up-front check, the error
    would be swallowed inside the search and sigma_max returned silently.
    """
    assert epsilon == 0, "Only epsilon=0 is supported"
    assert 0 < delta < 1, "Delta must be between 0 and 1"
    assert mu_0.shape == mu_1.shape == (dim,), "Invalid mean vectors"
    if method not in ('exact', 'approximate'):
        raise ValueError(f"Invalid method: {method}")
    if np.linalg.norm(mu_0 - mu_1, ord=2) == 0:
        return 0.0

    # Binary search
    while sigma_max - sigma_min > tolerance:
        sigma_mid = (sigma_min + sigma_max) / 2

        try:
            distance = compute_standard_elliptical_gaussian_statistical_distance(
                sigma=sigma_mid, dim=dim, mu_0=mu_0, mu_1=mu_1, prec=prec, method=method
            )
            distance = float(distance[0] if isinstance(distance, (tuple, list, np.ndarray)) else distance)
        except Exception as err:  # e.g. a numerical failure inside the mpmath integration
            # Best effort: treat a failed evaluation as "not private enough" and
            # raise the lower bound, but report it instead of hiding it.
            logging.warning("Statistical distance evaluation failed at sigma=%s: %s", sigma_mid, err)
            sigma_min = sigma_mid
            continue

        if distance > delta:
            sigma_min = sigma_mid
        else:
            sigma_max = sigma_mid

    return sigma_max

def cdf_spherical_gaussian_plrv(x, k, T, sigma, s, atol=1e-12, precision=50, r_max=None, method='approximate'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical Gaussian distribution with parameters k, T, sigma, s.

    Pr[X <= x] is evaluated as an integral over the radius r, weighted by the
    chi density with k degrees of freedom, of angle_cdf(w*(r)). Here w*(r) in
    [-1, 1] is the angular value at which g(r, w) crosses zero.

    Parameters:
    - x: input value Pr[X <= x]
    - k: degrees of freedom of the chi-distribution
    - T: dimension of the spherical Gaussian distribution
    - sigma: noise scale
    - s: l2 sensitivity
    - atol: tolerance for the root finder; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - r_max: upper limit of the radial integral (default 20 * sqrt(k))
    - method: 'exact' uses the regularized incomplete beta CDF of W;
      'approximate' uses a normal CDF with variance 1/(T-1)

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].

    Raises:
    AssertionError if ``method`` is not 'approximate' or 'exact'.
    """
    # Validate first. Otherwise angle_cdf is never bound and the failure only
    # surfaces as an obscure NameError inside the integrand.
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    k = int(k)
    T = int(T)
    sigma = mp.mpf(sigma)
    s = mp.mpf(s)
    atol = mp.mpf(atol)

    sqrt_T = mp.sqrt(T - 1)
    c1 = mp.mpf('0.5') * sigma * (T - k) / s
    c2 = mp.mpf('2.0') * s / sigma
    c3 = s**2 / sigma**2

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    else:
        angle_cdf = lambda z: norm_cdf_mp(z)

    def g_modified(r, w):
        # Use mp.log1p for high precision
        return r * w + c1 * mp.log1p(c2 * w / r + c3 / r**2) - x

    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1)):
        # If g has one sign over all of [w_min, w_max], return the matching end
        # point (consistent with g increasing in w). Otherwise bisect for the root.
        if g_modified(r, w_min) > 0:
            return w_min
        if g_modified(r, w_max) < 0:
            return w_max

        return mp.findroot(lambda w: g_modified(r, w), (w_min, w_max), solver='bisect', tol=atol)

    def chi_pdf_mp(r, k):
        # PDF of chi distribution with k degrees of freedom at r
        # chi.pdf(r, df=k) = 2^{1-k/2} / Gamma(k/2) * r^{k-1} * exp(-r^2/2)
        if r < 0:
            return mp.mpf(0)
        norm_const = mp.power(2, 1 - k / 2) / mp.gamma(k / 2)
        return norm_const * mp.power(r, k - 1) * mp.exp(-r**2 / 2)

    def norm_cdf_mp(z):
        # Normal CDF with variance 1/(T-1); the sqrt_T scaling is applied here.
        return mp.mpf('0.5') * (1 + mp.erf(sqrt_T * z / mp.sqrt(2)))

    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return chi_pdf_mp(r, k) * angle_cdf(w_star)

    if r_max is None:
        r_max = mp.mpf('20.0') * mp.sqrt(k)
    else:
        r_max = mp.mpf(r_max)

    return mp.quad(integrand, [mp.mpf(0), r_max], method='gauss-legendre', error=atol)


def compute_spherical_gaussian_privacy(epsilon, sigma, dof, dim, mu_0, mu_1, prec=50, method='approximate'):
    """
    Compute delta(epsilon) for the spherical Gaussian mechanism from the CDF of its privacy loss.

    With F = cdf_spherical_gaussian_plrv and s = ||mu_0 - mu_1||_2, returns
    max(0, (1 - F(pos_term)) - exp(epsilon) * F(neg_term)), where
    pos_term / neg_term = (+-2 sigma^2 epsilon - s^2) / (2 sigma s).

    When mu_0 == mu_1 the two output distributions coincide, so the result is
    their exact hockey-stick divergence max(0, 1 - exp(epsilon)). This also
    avoids dividing by s == 0.

    NOTE(review): this module defines compute_spherical_gaussian_privacy twice;
    the later definition is the one bound at import time.
    """
    s_float = float(np.linalg.norm(mu_0 - mu_1, ord=2))
    if s_float == 0:
        return max(0, float(1 - np.exp(float(epsilon))))

    s = mp.mpf(s_float)
    T = mp.mpf(dim)
    k = mp.mpf(dof)
    sigma = mp.mpf(sigma)
    epsilon = mp.mpf(epsilon)

    neg_term = (- 2*sigma**2*epsilon - s**2)/(2*sigma*s)
    pos_term = (2*sigma**2*epsilon - s**2)/(2*sigma*s)

    pos_cdf = cdf_spherical_gaussian_plrv(pos_term, k, T, sigma, s, precision=prec, method=method)
    neg_cdf = cdf_spherical_gaussian_plrv(neg_term, k, T, sigma, s, precision=prec, method=method)

    # pos_cdf / neg_cdf are mp.quad(..., error=...) results; [0] is the value.
    return max(0, float((mp.mpf(1) - pos_cdf[0]) - mp.exp(epsilon)*neg_cdf[0]))

def cdf_spherical_truncated_gaussian_plrv(x, T, m, sigma, s, atol=1e-12, precision=50, r_max=None, method='approximate'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical truncated Gaussian distribution with parameters T, m, sigma, s.

    Pr[X <= x] is evaluated as an integral over the radius r, weighted by the
    density of N(m, sigma^2) truncated to r >= 0, of angle_cdf(w*(r)). Here
    w*(r) in [-1, 1] is the angular value at which g(r, w) crosses zero.

    Parameters:
    - x: input value Pr[X <= x]
    - m: location of N(m, sigma^2) before truncation to r >= 0
    - T: dimension of the spherical truncated Gaussian distribution
    - sigma: scale of the truncated Gaussian distribution
    - s: l2 sensitivity
    - atol: tolerance for the root finder; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - r_max: upper limit of the radial integral (default m + 20 * sigma)
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].
    """
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    m = mp.mpf(m)
    T = int(T)
    sigma = mp.mpf(sigma)
    s = mp.mpf(s)
    atol = mp.mpf(atol)
    sqrt_T = mp.sqrt(T - 1)

    # Coefficients used by g_modified below.
    c1 = sigma**2*(T-1)/(2*s)
    c2 = m/s

    # CDF of the angular coordinate at the threshold w*(r). 'approximate' scales
    # by sqrt(T-1), i.e. a normal CDF with standard deviation 1/sqrt(T-1).
    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    elif method == 'approximate':
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def g_modified(r, w):
        # D_rw = sqrt(r^2 + 2 r s w + s^2), so 1 + ln_arg = D_rw^2 / r^2.
        ln_arg = 2*s*w/r + s**2/r**2
        D_rw = mp.sqrt(r ** 2 + 2 * r * s * w + s ** 2)
        return c1 * mp.log1p(ln_arg) - c2 * (D_rw - r) + r * w - x
    
    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1)):
        # If g has one sign over all of [w_min, w_max], return the matching end
        # point. This is consistent with g increasing in w, which is assumed
        # here and not checked. Otherwise bisect for the sign change.
        if g_modified(r, w_min) > 0:
            return w_min
        if g_modified(r, w_max) < 0:
            return w_max

        return mp.findroot(lambda w: g_modified(r, w), (w_min, w_max), solver='bisect', tol=atol)
    
    def trunc_gauss_pdf(r, m, sigma):
        """
        PDF of a Gaussian N(m, sigma^2) truncated to [0, +∞).
        Returns 0 for r < 0.
        """
        if r < 0:
            return mp.mpf('0')
        # Normalising constant Φ(m/σ) = Pr[N(m, σ^2) > 0]; recomputed on every call.
        denom = sigma * mp.sqrt(2 * mp.pi) * norm_cdf_mp(m / sigma)
        return mp.exp(- (r - m) ** 2 / (2 * sigma ** 2)) / denom

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath. MEAN = 0, VAR = 1
        # (any sqrt(T-1) scaling is applied by the caller, see angle_cdf above)
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))
    
    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return trunc_gauss_pdf(r, m, sigma) * angle_cdf(w_star)

    if r_max is None:
        r_max = m +mp.mpf('20.0') * sigma  # TODO(review): confirm the m + 20*sigma default cut-off is needed/sufficient
    else:
        r_max = mp.mpf(r_max)
    
    return mp.quad(integrand, [mp.mpf(0), r_max], method='gauss-legendre', error=atol)


def compute_spherical_truncated_gaussian_privacy(epsilon, sigma, m, dim, mu_0, mu_1, prec=50, method='approximate'):
    """
    Compute delta(epsilon) for the spherical truncated Gaussian mechanism.

    With F = cdf_spherical_truncated_gaussian_plrv and s = ||mu_0 - mu_1||_2,
    returns max(0, (1 - F(pos_term)) - exp(epsilon) * F(neg_term)), where
    pos_term / neg_term = (+-2 sigma^2 epsilon - s^2) / (2 s).

    When mu_0 == mu_1 the output distributions coincide, so the result is
    max(0, 1 - exp(epsilon)). This also avoids dividing by s == 0.

    Raises:
    AssertionError if m <= 0, sigma <= 0, or m^2 > 4 sigma^2 (dim - 1)
    (that parameter regime is not implemented).
    """
    assert m > 0, "The mean must be positive"
    assert sigma > 0, "The standard deviation must be positive"
    assert m**2 <= 4*sigma**2*(dim-1), "the mean and sigma setting has not been implemented"

    s_float = float(np.linalg.norm(mu_0 - mu_1, ord=2))
    if s_float == 0:
        return max(0, float(1 - np.exp(float(epsilon))))

    s = mp.mpf(s_float)
    T = mp.mpf(dim)
    sigma = mp.mpf(sigma)
    m = mp.mpf(m)
    epsilon = mp.mpf(epsilon)

    neg_term = (- 2*sigma**2*epsilon - s**2)/(2*s)
    pos_term = (2*sigma**2*epsilon - s**2)/(2*s)

    pos_cdf = cdf_spherical_truncated_gaussian_plrv(pos_term, T, m, sigma, s, precision=prec, method=method)
    neg_cdf = cdf_spherical_truncated_gaussian_plrv(neg_term, T, m, sigma, s, precision=prec, method=method)

    # [0] selects the value from the mp.quad(..., error=...) result.
    return max(0, float((mp.mpf(1) - pos_cdf[0]) - mp.exp(epsilon)*neg_cdf[0]))

def cdf_spherical_gamma_plrv_lower_part(x, T, alpha, theta, s, atol=1e-12, precision=50, r_max=None, method='exact'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical gamma distribution with parameters T, alpha, theta, s.

    Used for alpha <= T. Pr[X <= x] is evaluated as an integral over the radius
    r, weighted by the gamma(alpha, theta) density, of angle_cdf(w*(r)). Here
    w*(r) in [-1, 1] is the angular value at which g(r, w) crosses zero.

    Parameters:
    - x: input value Pr[X <= x]
    - T: dimension of the spherical gamma distribution
    - alpha: shape parameter of the gamma distribution (0 < alpha <= T)
    - theta: scale parameter of the gamma distribution
    - s: l2 sensitivity
    - atol: tolerance for the root finder; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - r_max: upper limit of the radial integral (default mean + 20 std)
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].

    Raises:
    AssertionError for an unknown method or out-of-range parameters.
    """
    # Validate first; otherwise angle_cdf would be left unbound.
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    assert alpha > 0, "The shape parameter alpha must be positive"
    assert alpha <= T, "The shape parameter alpha must be less than or equal to the dimension T"
    assert theta > 0, "The scale parameter theta must be positive"
    assert s > 0, "The l2 sensitivity s must be positive"

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    T = int(T)
    alpha = mp.mpf(alpha)
    theta = mp.mpf(theta)
    s = mp.mpf(s)
    atol = mp.mpf(atol)

    sqrt_T = mp.sqrt(T - 1)
    c1 = theta * (T - alpha) / 2

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    else:
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def g_modified(r, w):
        # Use mp.log1p for high precision; 1 + (2 s w / r + s^2 / r^2) = D_rw^2 / r^2.
        D_rw = mp.sqrt(r ** 2 + 2 * r * s * w + s ** 2)
        return D_rw - r + c1 * mp.log1p(2*s*w/r  + s**2/r**2) - x

    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1)):
        # If g has one sign over all of [w_min, w_max], return the matching end
        # point (consistent with g increasing in w). Otherwise bisect for the root.
        if g_modified(r, w_min) > 0:
            return w_min
        if g_modified(r, w_max) < 0:
            return w_max

        return mp.findroot(lambda w: g_modified(r, w), (w_min, w_max), solver='bisect', tol=atol)

    def gamma_pdf_mp(r, alpha, theta):
        # PDF of gamma distribution with shape parameter alpha and scale parameter theta at r
        # gamma.pdf(r, alpha, theta) = r^{alpha-1} * exp(-r/theta) / (theta^alpha * Gamma(alpha))
        if r < 0:
            return mp.mpf(0)
        return (r ** (alpha - 1)) * mp.exp(-r / theta) / (theta ** alpha * mp.gamma(alpha))

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))

    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return gamma_pdf_mp(r, alpha, theta) * angle_cdf(w_star)

    mean = alpha * theta
    std = mp.sqrt(alpha) * theta

    if r_max is None:
        r_max = mean + mp.mpf('20.0') * std
    else:
        r_max = mp.mpf(r_max)

    return mp.quad(integrand, [mp.mpf(0), r_max], method='gauss-legendre', error=atol)


def cdf_spherical_gamma_plrv_upper_part(x, T, alpha, theta, s, atol=1e-12, precision=100, r_max=None, method='exact'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical gamma
    distribution with parameters T, alpha, theta, s, in the regime alpha >= T.

    Setup:
    - For radius r and angular coordinate w, let a = s / r and
      y = 1 + 2 a w + a^2. As w ranges over [-1, 1], y ranges over
      [(1 - a)^2, (1 + a)^2].
    - Define h(r, y) = r (sqrt(y) - 1) + c1 log(y) - x, with
      c1 = theta (T - alpha) / 2 <= 0.
    - h is minimised at y0 = 4 c1^2 / r^2, so the set where h < 0 is an
      interval (y1, y2) around y0.

    That interval is intersected with the feasible range of y, mapped back to
    w, and its angular mass is weighted by the gamma density of r.

    Parameters:
    - x: input value Pr[X <= x]
    - T: dimension of the spherical gamma distribution
    - alpha: shape parameter of the gamma distribution (alpha >= T)
    - theta: scale parameter of the gamma distribution
    - s: l2 sensitivity
    - atol: root-finding tolerance; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - r_max: upper limit of the radial integral (default mean + 20 std)
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].

    NOTE(review): alpha == T gives c1 == 0 and y0 == 0, where log(y0) is
    undefined. compute_spherical_gamma_privacy only routes alpha > T here.
    """
    # Validate first; otherwise angle_cdf would be left unbound.
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    assert alpha > 0, "The shape parameter alpha must be positive"
    assert alpha >= T, "The shape parameter alpha must be greater than the dimension T"
    assert theta > 0, "The scale parameter theta must be positive"
    assert s > 0, "The l2 sensitivity s must be positive"

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    T = int(T)
    alpha = mp.mpf(alpha)
    theta = mp.mpf(theta)
    s = mp.mpf(s)
    atol = mp.mpf(atol)

    sqrt_T = mp.sqrt(T - 1)
    c1 = theta * (T - alpha) / 2

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    else:
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def h_modified(r, y):
        return r*(mp.sqrt(y) - 1) + c1 * mp.log(y) - x

    def gamma_pdf_mp(r, alpha, theta):
        # PDF of gamma distribution with shape parameter alpha and scale parameter theta at r
        # gamma.pdf(r, alpha, theta) = r^{alpha-1} * exp(-r/theta) / (theta^alpha * Gamma(alpha))
        if r < 0:
            return mp.mpf(0)
        return (r ** (alpha - 1)) * mp.exp(-r / theta) / (theta ** alpha * mp.gamma(alpha))

    def find_two_roots(r, eps=1e-12, Ymax=1e3):
        """Return the roots y1 <= y0 <= y2 of h(r, .) around its minimiser y0 (requires h(r, y0) < 0)."""
        h = lambda y: h_modified(r, y)
        # stationary point (minimum of h in y, since c1 <= 0)
        y0 = 4*c1**2/r**2

        # Left root: h decreases on (0, y0]. Keep the lower end below y0 even
        # when y0 is tiny.
        lo = min(mp.mpf(eps), y0 / 2)
        if h(lo) <= 0:
            # The left root lies in (0, lo]; treat it as 0. It is clamped to the
            # feasible range of y by the integrand anyway.
            y1 = mp.mpf(0)
        else:
            y1 = interval_bisection(h, lo, y0, tol=atol)

        # Right root: start strictly to the right of y0 (a fixed Ymax may lie
        # left of y0 and would bracket the *left* root), then grow until h > 0.
        hi = max(mp.mpf(Ymax), 2 * y0)
        while h(hi) < 0:
            hi *= 2

        y2 = mp.findroot(h, (y0, hi), solver='bisect', tol=atol)
        return y1, y2

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))

    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        a = s/r
        y0 = 4*c1**2/r**2

        # h(r, .) is minimised at y0; if even the minimum is >= 0, the event is empty.
        if h_modified(r, y0) >= 0:
            return 0

        y1, y2 = find_two_roots(r)
        # Intersect (y1, y2) with the feasible range [(1-a)^2, (1+a)^2] of y so
        # that the mapped w1, w2 always stay inside [-1, 1].
        y1 = max(y1, (1 - a)**2)
        y2 = min(y2, (1 + a)**2)

        if y1 > y2:
            return 0

        # Invert y = 1 + 2 a w + a^2.
        w1 = (y1 - a**2  - 1) / (2*a)
        w2 = (y2 - a**2  - 1) / (2*a)

        return gamma_pdf_mp(r, alpha, theta) * (angle_cdf(w2) - angle_cdf(w1))

    mean = alpha * theta
    std = mp.sqrt(alpha) * theta

    if r_max is None:
        r_max = mean + mp.mpf('20.0') * std
    else:
        r_max = mp.mpf(r_max)

    return mp.quad(integrand, [mp.mpf(0), r_max], method='gauss-legendre', error=atol)
    

def compute_spherical_gamma_privacy(epsilon, theta, alpha, dim, mu_0, mu_1, atol=1e-12, prec=50, method='exact'):
    """
    Compute delta(epsilon) for the spherical gamma mechanism.

    Let F be cdf_spherical_gamma_plrv_upper_part when alpha > dim, and
    cdf_spherical_gamma_plrv_lower_part otherwise. Returns
    max(0, (1 - F(epsilon*theta)) - exp(epsilon) * F(-epsilon*theta)).

    When mu_0 == mu_1 the two output distributions are identical, so the
    result is max(0, 1 - exp(epsilon)). This case is handled directly because
    both CDF helpers assert a positive sensitivity.
    """
    assert method in ['exact', 'approximate'], "Invalid method"

    s_float = float(np.linalg.norm(mu_0 - mu_1, ord=2))
    if s_float == 0:
        return max(0, float(1 - np.exp(float(epsilon))))

    s = mp.mpf(s_float)
    T = mp.mpf(dim)
    theta = mp.mpf(theta)
    alpha = mp.mpf(alpha)
    epsilon = mp.mpf(epsilon)

    pos_term = epsilon*theta
    neg_term = -epsilon*theta

    if alpha > T:
        pos_cdf = cdf_spherical_gamma_plrv_upper_part(pos_term, T, alpha, theta, s, atol=atol, precision=prec, method=method)
        neg_cdf = cdf_spherical_gamma_plrv_upper_part(neg_term, T, alpha, theta, s, atol=atol, precision=prec, method=method)
    else:
        pos_cdf = cdf_spherical_gamma_plrv_lower_part(pos_term, T, alpha, theta, s, atol=atol, precision=prec, method=method)
        neg_cdf = cdf_spherical_gamma_plrv_lower_part(neg_term, T, alpha, theta, s, atol=atol, precision=prec, method=method)

    # [0] selects the value from the mp.quad(..., error=...) result.
    return max(0, float((mp.mpf(1) - pos_cdf[0]) - mp.exp(epsilon)*neg_cdf[0]))

def interval_bisection(f, a, b, tol=1e-12, maxiter=200):
    """
    Find a root of f in [a, b] by bisection.

    Parameters:
    - f: callable; f(a) and f(b) must not have the same (strict) sign
    - a, b: end points of the bracket
    - tol: stop once the bracket width |b - a| is below tol
    - maxiter: maximum number of halvings

    Returns:
    An exact root as soon as one is evaluated (including an end point);
    otherwise the midpoint of the final bracket.

    Raises:
    - ValueError if the root is not bracketed
    - RuntimeError if the bracket does not shrink below tol within maxiter steps
    """
    fa, fb = f(a), f(b)
    if fa * fb > 0:
        raise ValueError("Root is not bracketed: f(a) and f(b) have same sign.")
    # An end point that is already a root would make fa * fc == 0 on every
    # iteration and drag the bracket away from the root.
    if fa == 0:
        return a
    if fb == 0:
        return b

    for _ in range(maxiter):
        c = (a + b) / 2
        fc = f(c)
        # Stop on an exact root or once the interval is small enough
        if fc == 0 or abs(b - a) < tol:
            return c
        # Keep the half whose end points still have opposite signs
        if fa * fc < 0:
            b, fb = c, fc
        else:
            a, fa = c, fc
    raise RuntimeError("Bisection did not converge")

def compute_truncated_gaussian_noise_given_sigma(epsilon, delta, sigma, dim, mu_0, mu_1, prec=50, method='exact', m_low=None, m_high=None, atol=1e-12, rtol=1e-3):
    """
    Find the smallest m such that the privacy loss is at most delta for given epsilon, sigma, and other parameters.

    Procedure:
    1. Double ``m_high`` (default sigma * sqrt(dim)) until
       delta(m_high) < delta.
    2. Bisect on [m_low, m_high] (default m_low = 1e-8), assuming delta(m) is
       non-increasing in m.
    3. Stop when |delta(m) - delta| <= atol, or when the bracket becomes
       narrower than atol.

    Note: ``rtol`` is currently unused.

    Returns:
    The last m that was evaluated. If the bisection is not needed, this is
    m_high itself.
    """
    assert method in ['exact', 'approximate'], "Invalid method"
    if m_low is None:
        m_low = 1e-8
    if m_high is None:
        m_high = sigma*np.sqrt(dim)

    delta_val = compute_spherical_truncated_gaussian_privacy(epsilon, sigma, m_high, dim, mu_0, mu_1, prec=prec, method=method)
    logging.info(f"[m_high: {m_high}, sigma: {sigma}, dim: {dim}, epsilon: {epsilon}] delta_val: {delta_val}")

    # Grow the upper bound until it satisfies the delta constraint.
    while delta_val >= delta:
        m_high *= 2
        delta_val = compute_spherical_truncated_gaussian_privacy(epsilon, sigma, m_high, dim, mu_0, mu_1, prec=prec, method=method)
        logging.info(f"[m_high: {m_high}, sigma: {sigma}, dim: {dim}, epsilon: {epsilon}] delta_val: {delta_val}")

    # m_mid tracks the last evaluated m. Start from m_high so a value is
    # returned even when the bisection loop below never runs.
    m_mid = m_high
    while abs(delta_val - delta) > atol:
        m_mid = (m_low + m_high) / 2
        delta_val = compute_spherical_truncated_gaussian_privacy(
            epsilon, sigma, m_mid, dim, mu_0, mu_1, prec=prec, method=method
        )
        logging.info(f"[m_mid: {m_mid}, sigma: {sigma}, dim: {dim}, epsilon: {epsilon}] delta_val: {delta_val}")
        if delta_val > delta:
            m_low = m_mid
        else:
            m_high = m_mid
        if abs(m_high - m_low) < atol:
            logging.info("[Warning] Binary search stopped due to tolerance: |m_high - m_low| < tol. Cannot achieved the desired accuracy.")
            break

    return m_mid
    
def compute_spherical_gaussian_privacy(epsilon, sigma, dof, dim, mu_0, mu_1, prec=50, method='approximate'):
    """
    Compute delta(epsilon) for the spherical Gaussian mechanism from the CDF of its privacy loss.

    With F = cdf_spherical_gaussian_plrv and s = ||mu_0 - mu_1||_2, returns
    max(0, (1 - F(pos_term)) - exp(epsilon) * F(neg_term)), where
    pos_term / neg_term = (+-2 sigma^2 epsilon - s^2) / (2 sigma s).

    When mu_0 == mu_1 the two output distributions coincide, so the result is
    their exact hockey-stick divergence max(0, 1 - exp(epsilon)). This also
    avoids dividing by s == 0.

    NOTE(review): this module defines compute_spherical_gaussian_privacy twice;
    this is the later definition, the one bound at import time.
    """
    s_float = float(np.linalg.norm(mu_0 - mu_1, ord=2))
    if s_float == 0:
        return max(0, float(1 - np.exp(float(epsilon))))

    s = mp.mpf(s_float)
    T = mp.mpf(dim)
    k = mp.mpf(dof)
    sigma = mp.mpf(sigma)
    epsilon = mp.mpf(epsilon)

    neg_term = (- 2*sigma**2*epsilon - s**2)/(2*sigma*s)
    pos_term = (2*sigma**2*epsilon - s**2)/(2*sigma*s)

    pos_cdf = cdf_spherical_gaussian_plrv(pos_term, k, T, sigma, s, precision=prec, method=method)
    neg_cdf = cdf_spherical_gaussian_plrv(neg_term, k, T, sigma, s, precision=prec, method=method)

    # pos_cdf / neg_cdf are mp.quad(..., error=...) results; [0] is the value.
    return max(0, float((mp.mpf(1) - pos_cdf[0]) - mp.exp(epsilon)*neg_cdf[0]))

def cdf_spherical_Weibull_plrv(x, T, k, lam, s, atol=1e-12, precision=50, r_max=None, method='approximate'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical Weibull distribution with parameters T, k, lam, s.

    The radius r follows a Weibull(k, lam) density. The result integrates that
    density against 1 - angle_cdf(w*(r)), where w*(r) in [-1, 1] is the
    angular value at which g(r, w) changes sign.

    Parameters:
    - x: input value Pr[X <= x]
    - k: shape parameter of the Weibull distribution
    - lam: scale parameter of the Weibull distribution
    - T: dimension of the spherical Weibull distribution
    - s: l2 sensitivity
    - atol: tolerance for the root finder; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - r_max: upper limit of the radial integral. If None, defaults to
      20 * lam * (-log(atol))^(1/k); an explicit value is now honoured.
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``.
    """
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    k = mp.mpf(k)
    lam = mp.mpf(lam)
    T = int(T)
    s = mp.mpf(s)
    atol = mp.mpf(atol)
    sqrt_T = mp.sqrt(T - 1)

    c1 = (k - T) / 2

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    elif method == 'approximate':
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def g_modified(r, w):
        # D_rw = sqrt(r^2 + 2 r s w + s^2), so 1 + ln_arg = D_rw^2 / r^2.
        ln_arg = 2*s*w/r + s**2/r**2
        D_rw = mp.sqrt(r ** 2 + 2 * r * s * w + s ** 2)
        return c1 * mp.log1p(ln_arg) + (r/lam)**k - (D_rw/lam)**k - x

    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1)):
        # If g has one sign over all of [w_min, w_max], return the matching end
        # point (consistent with g decreasing in w). Otherwise bisect for the root.
        if g_modified(r, w_max) > 0:
            return w_max
        if g_modified(r, w_min) < 0:
            return w_min

        return mp.findroot(lambda w: g_modified(r, w), (w_min, w_max), solver='bisect', tol=atol)

    def weibull_pdf(r):
        if r < 0:
            return mp.mpf(0)
        return (k/lam) * (r/lam)**(k-1) * mp.exp(-(r/lam)**k)

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))

    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return weibull_pdf(r) * (1 - angle_cdf(w_star))

    if r_max is None:
        # At this radius the Weibull tail exp(-(r/lam)^k) equals atol^(20^k),
        # far below atol.
        r_max = 20 * lam * (-mp.log(atol))**(1.0 / k)
    else:
        r_max = mp.mpf(r_max)

    return mp.quad(integrand, [mp.mpf(0), r_max], error=atol)


def cdf_spherical_Gamma_typeI_plrv(x, T, beta, p, s, atol=1e-12, precision=50, method='exact'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical generalized Gamma distribution with parameters T, beta, p. Where alpha = T - 1

    The radius r has density proportional to r^alpha * exp(-beta * r^p) with
    alpha = T - 1. The result integrates that density against
    1 - angle_cdf(w*(r)), where w*(r) in [-1, 1] is the angular value at which
    g(r, w) changes sign.

    Parameters:
    - x: input value Pr[X <= x]
    - beta: coefficient of r^p in the exponent of the radial density
    - p: power of r in the exponent of the radial density
    - T: dimension of the spherical Generalized Gamma distribution
    - s: l2 sensitivity
    - atol: tolerance for the root finder; also passed as ``error=`` to ``mp.quad``
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].
    """
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    beta = mp.mpf(beta)
    p = mp.mpf(p)
    T = int(T)
    s = mp.mpf(s)
    atol = mp.mpf(atol)
    sqrt_T = mp.sqrt(T - 1)
    alpha = T - 1

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    elif method == 'approximate':
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def g_modified(r, w):
        D_rw = mp.sqrt(r ** 2 + 2*s*w*r + s ** 2)
        return beta*(r**p - D_rw**p) - x

    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1)):
        # If g has one sign over all of [w_min, w_max], return the matching end
        # point (consistent with g decreasing in w). Otherwise bisect for the root.
        if g_modified(r, w_max) > 0:
            return w_max
        if g_modified(r, w_min) < 0:
            return w_min

        return mp.findroot(lambda w: g_modified(r, w), (w_min, w_max), solver='bisect', tol=atol)

    def generalized_gamma_pdf(r):
        if r < 0:
            return mp.mpf(0)
        # p * beta^((alpha+1)/p) / Gamma((alpha+1)/p) normalises r^alpha * exp(-beta r^p).
        norm_const = p * mp.power(beta, (alpha + 1)/p) / mp.gamma((alpha + 1)/p)
        pdf = norm_const * mp.power(r, alpha) * mp.exp(-beta * mp.power(r, p))
        return pdf

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))

    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return generalized_gamma_pdf(r) * (1 - angle_cdf(w_star))

    # Radial cut-off: 200 times the first moment of the radial distribution.
    r_max = generalized_gamma_moments(alpha, beta, p, moment = 1, prec=50)*200
    # generalized_gamma_moments takes its own ``prec``. Restore the working
    # precision in case it changed the global mp.mp.dps, as the
    # generalized-gamma variant of this function also does.
    mp.mp.dps = precision

    return mp.quad(integrand, [mp.mpf(0), r_max], error=atol)


def compute_spherical_Gamma_typeI_privacy(epsilon, beta, p, dim, mu_0, mu_1, atol=1e-12, prec=50, method='exact'):
    """
    Compute delta(epsilon) for the spherical generalized Gamma (type I, alpha = dim - 1) mechanism.

    With F = cdf_spherical_Gamma_typeI_plrv and s = ||mu_0 - mu_1||_2, returns
    max(0, F(-epsilon) - exp(epsilon) * (1 - F(epsilon))).
    """
    sensitivity = mp.mpf(np.linalg.norm(mu_0 - mu_1, ord=2))
    T = mp.mpf(dim)
    beta = mp.mpf(beta)
    p = mp.mpf(p)
    epsilon = mp.mpf(epsilon)

    shared = dict(atol=atol, precision=prec, method=method)
    cdf_at_minus_eps = cdf_spherical_Gamma_typeI_plrv(-epsilon, T, beta, p, sensitivity, **shared)
    cdf_at_plus_eps = cdf_spherical_Gamma_typeI_plrv(epsilon, T, beta, p, sensitivity, **shared)

    # Index [0] picks the value out of the mp.quad(..., error=...) result.
    value = cdf_at_minus_eps[0] - mp.exp(epsilon) * (1 - cdf_at_plus_eps[0])
    return max(0, float(value))


def cdf_spherical_generalized_gamma_plrv(x, T, alpha, beta, p, s, atol=1e-15, precision=50, method='exact'):
    """
    Compute the CDF of the privacy loss random variable X of the spherical generalized Gamma distribution with parameters T, alpha, beta, p.

    The radius r has density proportional to r^alpha * exp(-beta * r^p) (see
    generalized_gamma_pdf below). The result integrates that density against
    1 - angle_cdf(w*(r)), where w*(r) in [-1, 1] is the angular threshold
    returned by w_star_func.

    Parameters:
    - x: input value Pr[X <= x]
    - alpha: power of r in the radial density r^alpha * exp(-beta * r^p)
    - beta: coefficient of r^p in the exponent of the radial density
    - p: power of r in the exponent of the radial density
    - T: dimension of the spherical Generalized Gamma distribution
    - s: l2 sensitivity
    - atol: passed as ``error=`` to ``mp.quad``. The bisection tolerance in
      w_star_func is max(1e-12, atol).
    - precision: mpmath decimal precision (sets the global ``mp.mp.dps``)
    - method: 'exact' (regularized incomplete beta CDF of W) or 'approximate'
      (normal CDF with variance 1/(T-1))

    Returns:
    The result of ``mp.quad(..., error=atol)``. Callers in this module index it with [0].
    """
    assert method in ['approximate', 'exact'], "Invalid method"

    mp.mp.dps = precision

    # Convert all parameters to mpf for high precision
    x = mp.mpf(x)
    alpha = mp.mpf(alpha)
    beta = mp.mpf(beta)
    p = mp.mpf(p)
    T = int(T)
    s = mp.mpf(s)
    atol = mp.mpf(atol)
    sqrt_T = mp.sqrt(T - 1)

    # Coefficient of the log term in g_modified.
    c1 = (alpha + 1 - T)/2

    if method == 'exact':
        angle_cdf = lambda z: w_cdf_mp(z)
    elif method == 'approximate':
        angle_cdf = lambda z: norm_cdf_mp(sqrt_T * z)

    def g_modified(r, w):
        # D_rw = sqrt(r^2 + 2 r s w + s^2), so 1 + ln_arg = D_rw^2 / r^2.
        ln_arg = 2*s*w/r + s**2/r**2
        D_rw = mp.sqrt(r ** 2 + 2*s*w*r + s ** 2)
        return c1*mp.log1p(ln_arg) + beta*(r**p - D_rw**p) - x
    
    def w_star_func(r, w_min=mp.mpf(-1), w_max=mp.mpf(1), *, max_iter=80):
        """
        Solve for w in [w_min, w_max] such that g_modified(r, w) = 0.

        Returns w_max if g(w_max) >= 0 and w_min if g(w_min) <= 0. Otherwise
        bisects, keeping g(lo) >= 0 and g(hi) < 0, for at most max_iter
        halvings or until hi - lo <= max(1e-12, atol).

        IMPORTANT: this must *never* return None, since downstream code assumes
        a numeric w in [-1, 1]. Every path above returns a boundary value or a
        bisection end point.
        """
        g_hi = g_modified(r, w_max)
        if g_hi >= 0:
            return w_max

        g_lo = g_modified(r, w_min)
        if g_lo <= 0:
            return w_min

        # Now we should have g_lo > 0 and g_hi < 0, so a root is bracketed.
        lo = mp.mpf(w_min)
        hi = mp.mpf(w_max)

        # Choose a practical tolerance for w; the integrator tolerance can be
        # extremely tight and make root-finding brittle.
        w_atol = mp.mpf("1e-12")
        if atol > 0:
            w_atol = max(w_atol, atol)

        for _ in range(max_iter):
            mid = (lo + hi) / 2
            g_mid = g_modified(r, mid)

            if g_mid >= 0:
                lo = mid
            else:
                hi = mid

            if hi - lo <= w_atol:
                break

        # Return the last point with g >= 0 (monotone-safe convention).
        return lo
    
    def generalized_gamma_pdf(r):
        if r < 0:
            return mp.mpf(0)
        # p * beta^((alpha+1)/p) / Gamma((alpha+1)/p) normalises r^alpha * exp(-beta r^p)
        # on (0, inf); it is recomputed on every call.
        norm_const = p * mp.power(beta, (alpha + 1)/p) / mp.gamma((alpha + 1)/p)
        pdf = norm_const * mp.power(r, alpha) * mp.exp(-beta * mp.power(r, p))
        return pdf

    def norm_cdf_mp(z):
        # Standard normal CDF using mpmath
        return mp.mpf('0.5') * (1 + mp.erf(z / mp.sqrt(2)))
    
    def w_cdf_mp(z):
        # CDF of W where (W + 1) / 2 ~ Beta((T-1)/2, (T-1)/2).
        return mp.betainc((T-1)/2, (T-1)/2, 0, (z+1)/2, regularized=True)

    def integrand(r):
        w_star = w_star_func(r)
        return generalized_gamma_pdf(r) * (1 - angle_cdf(w_star))

    # Radial cut-off: 200 times the first moment. TODO(review): is a finite cut-off needed?
    r_max = generalized_gamma_moments(alpha, beta, p, moment = 1, prec=50)*200
    # Restore the working precision; presumably generalized_gamma_moments may change mp.mp.dps.
    mp.mp.dps = precision
    
    return mp.quad(integrand, [mp.mpf(0), r_max], error=atol)

def compute_spherical_generalized_gamma_privacy(epsilon, alpha, beta, p, dim, mu_0, mu_1, prec=50, method='exact', atol=1e-15):
    """
    Compute delta(epsilon) for the spherical generalized Gamma mechanism.

    With F = cdf_spherical_generalized_gamma_plrv and s = ||mu_0 - mu_1||_2,
    returns max(0, F(-epsilon) - exp(epsilon) * (1 - F(epsilon))).

    Parameters:
    - alpha: power of r in the radial density r^alpha * exp(-beta * r^p);
      must satisfy 0 <= alpha <= dim - 1
    - beta: positive coefficient of r^p in the exponent of the radial density
    - p: positive power of r in that exponent
    - prec: mpmath decimal precision passed to the CDF
    - method: 'exact' or 'approximate'
    - atol: tolerance forwarded to the CDF (the default equals the CDF's own
      default, so existing callers are unaffected)
    """
    assert alpha >= 0 and alpha <= dim-1, "alpha must be non-negative and at most dim-1"
    assert beta > 0, "beta must be positive"
    assert p > 0, "p must be positive"

    s = mp.mpf(np.linalg.norm(mu_0 - mu_1, ord=2))
    T = mp.mpf(dim)
    alpha = mp.mpf(alpha)
    beta = mp.mpf(beta)
    p = mp.mpf(p)
    epsilon = mp.mpf(epsilon)

    neg_term = epsilon
    pos_term = -epsilon

    pos_cdf = cdf_spherical_generalized_gamma_plrv(pos_term, T, alpha, beta, p, s, atol=atol, precision=prec, method=method)
    neg_cdf = cdf_spherical_generalized_gamma_plrv(neg_term, T, alpha, beta, p, s, atol=atol, precision=prec, method=method)

    # [0] selects the value from the mp.quad(..., error=...) result.
    return max(0, float(pos_cdf[0] - mp.exp(epsilon)*(1 - neg_cdf[0])))