import numpy as np
import mpmath as mp
from scipy.linalg import eigh_tridiagonal
from multiprocessing import Pool
from bisect import bisect_left
from pathlib import Path

import os

from OAGN_analysis.accounting_gaussian import _pmf_power_fft, _tail_prob_from_pmf
from OAEGN_analysis.radial_rv import generalized_gamma_moments
import mpmath as mp
from numpy.polynomial.legendre import leggauss  
from scipy.special import betainc as sp_betainc
import time
from scipy.special import erfcinv

# Discretizer shared by the task functions running inside a pool worker.
# It is filled in once per worker process by the Pool initializer below.
_WORKER_DISC = None

def _init_worker(discretizer):
    """
    Pool initializer: remember `discretizer` in the module-level
    `_WORKER_DISC` so tasks can read it instead of receiving it per task.

    Afterwards a best-effort warm-up sets the mpmath precision from
    `discretizer.precision` and reads one or two elements of the lookup
    tables. Any failure during the warm-up is ignored on purpose; the global
    reference has already been stored at that point.
    """
    global _WORKER_DISC
    _WORKER_DISC = discretizer
    try:
        mp.mp.dps = discretizer.precision
        # Read the tables once so the first real task does not pay for it.
        _ = discretizer.w_tables[0]
        for table in discretizer.gr_tables:
            if len(table) > 0:
                _ = table[0] + table[-1]
    except Exception:
        pass

def gauss_laguerre_jacobi_solver(K, alpha):
    """
        Build the K x K Jacobi (symmetric tridiagonal) matrix of the generalized
        Laguerre polynomials for the weight x^alpha e^{-x} and diagonalize it
        in float64.

        Returns:
            (diag, off_diag, eig_vals, eig_vecs) where `diag` has K entries
            2k + 1 + alpha, `off_diag` has K-1 entries sqrt(k (k + alpha)),
            and `eig_vals` / `eig_vecs` come from `eigh_tridiagonal`.

        The eigenvalues are the quadrature nodes; the first component of each
        eigenvector is what the caller turns into a quadrature weight.
    """
    degree = np.arange(K, dtype=np.float64)
    sub_degree = degree[1:]

    # Three-term recurrence coefficients of the Laguerre family
    main_diagonal = 2.0 * degree + 1.0 + alpha
    side_diagonal = np.sqrt(sub_degree * (sub_degree + alpha))

    nodes, vectors = eigh_tridiagonal(main_diagonal, side_diagonal)
    return main_diagonal, side_diagonal, nodes, vectors

# Thomas algorithm in mpmath arithmetic for the shifted symmetric tridiagonal
# system (J - shift*I) y = b, where J has diagonal d and off-diagonal e.
def _solve_tridiag_shift_mp(d, e, b, shift):
    """
    Solve (J - shift*I) y = b without pivoting and return y as a list of mpf.

    The inputs are copied into mpf lists first, so d, e and b are not modified.
    A zero pivot raises ZeroDivisionError from the mpf division.
    """
    size = len(d)
    pivots = [mp.mpf(value) - shift for value in d]
    off = [mp.mpf(value) for value in e]
    rhs = [mp.mpf(value) for value in b]

    # Sweep down: eliminate the sub-diagonal
    for row in range(1, size):
        above = row - 1
        factor = off[above] / pivots[above]
        pivots[row] -= factor * off[above]
        rhs[row] -= factor * rhs[above]

    # Sweep up: back substitution from the last row
    solution = [mp.mpf('0')] * size
    solution[-1] = rhs[-1] / pivots[-1]
    for row in reversed(range(size - 1)):
        solution[row] = (rhs[row] - off[row] * solution[row + 1]) / pivots[row]
    return solution

def _J_times_q(d, e, q):
    n = len(d)
    Jq = [mp.mpf('0')]*n
    Jq[0] = d[0]*q[0] + e[0]*q[1]
    for i in range(1, n-1):
        Jq[i] = e[i-1]*q[i-1] + d[i]*q[i] + e[i]*q[i+1]
    Jq[-1] = e[-1]*q[-2] + d[-1]*q[-1]
    return Jq

def _process_single_eigenpair(args):
    """
    Refine one eigenpair of the Jacobi matrix with Rayleigh-quotient iteration
    in mpmath arithmetic (task function for parallel execution).

    Args:
        args: tuple (i, x0, V0, dmp, emp, mu0, max_iters, tol, return_logw):
            i           index of the eigenpair to refine
            x0, V0      seed eigenvalues / eigenvectors (column i is used)
            dmp, emp    diagonal / off-diagonal of the Jacobi matrix
            mu0         zeroth moment multiplying the squared first component
            max_iters   maximum number of refinement steps
            tol         stopping tolerance (residual norm, or relative change
                        of the eigenvalue estimate)
            return_logw if true, return log(weight) instead of the weight

    Returns:
        (i, eigenvalue, weight_or_log_weight)
    """
    i, x0, V0, dmp, emp, mu0, max_iters, tol, return_logw = args

    lam = mp.mpf(x0[i])                        # Rayleigh shift
    # seed eigenvector from fast solve (column i), convert to mp & normalize
    q = [mp.mpf(v) for v in V0[:, i]]
    qn = mp.sqrt(mp.fsum(qq*qq for qq in q))
    q = [qq/qn for qq in q]

    for _ in range(max_iters):
        try:
            z = _solve_tridiag_shift_mp(dmp, emp, q, shift=lam)
        except ZeroDivisionError:
            # The shifted matrix is exactly singular, i.e. the shift already
            # equals an eigenvalue to working precision (always the case for a
            # 1x1 matrix). Keep the current pair instead of aborting.
            break
        zn = mp.sqrt(mp.fsum(zz*zz for zz in z))
        q  = [zz/zn for zz in z]

        Jq = _J_times_q(dmp, emp, q)
        lam_new = mp.fsum(qj*Jqj for qj, Jqj in zip(q, Jq))

        # residual and convergence test
        rnorm = mp.sqrt(mp.fsum((Jqj - lam_new*qj)**2 for qj, Jqj in zip(q, Jq)))
        converged = rnorm <= tol or mp.fabs(lam_new - lam) <= tol*max(1, mp.fabs(lam))
        lam = lam_new
        if converged:
            break

    # Golub–Welsch: weight = μ0 * (first component)^2  (nonnegative in mp)
    if return_logw:
        w = mp.log(mu0) + 2*mp.log(mp.fabs(q[0]))
    else:
        w = mu0 * q[0]*q[0]

    return i, lam, w

def _process_single_r_value(args):
    """
    Process a single r value for parallel execution in Discretize_SGG_PLRV.
    """
    r, w_descending_table, func_g, _enforce_strict_increasing = args
    
    # ensure gs (output of g(r,w)) ascending, so we can use in binary search code 
    gs = [func_g(r, w) for w in w_descending_table]
    gs = _enforce_strict_increasing(gs)  # tiny eps to avoid ties
    
    return gs

def _process_y_chunk_cdf(y_chunk):
    """
    Evaluate the CDF for every y in `y_chunk` on the float64 path (task
    function for cdf_vec).

    The discretizer is read from the module-level `_WORKER_DISC`, which the
    Pool initializer sets, so it is not pickled with each task. Returns a list
    of Python floats, one per input, clipped to [0, 1].
    """
    disc = _WORKER_DISC

    n_nodes = len(disc.rk)
    node_indices = range(n_nodes)
    thresholds = np.empty(n_nodes, dtype=np.float64)
    lookup = disc._w_star_from_table_np
    survival = disc._S_W_vec
    node_weights = disc.wk_np

    results = []
    for y in y_chunk:
        y_val = float(y)
        # Interpolate the threshold w* at every radial node, then evaluate
        # the survival function for all nodes in one vectorized call.
        for k in node_indices:
            thresholds[k] = lookup(k, y_val)
        total = float(np.dot(node_weights, survival(thresholds)))
        # Clip tiny overshoots outside [0, 1]
        if total < 0:
            total = 0.0
        elif total > 1:
            total = 1.0
        results.append(total)

    return results

def _set_mp_worker_dps(dps):
    """
    Pool initializer: set the mpmath working precision inside a worker.

    Needed because worker processes started with the 'spawn' or 'forkserver'
    start methods do not inherit the parent's runtime `mp.mp.dps` setting.
    """
    mp.mp.dps = dps


def gauss_laguerre_mixed_mp(K, alpha=0.0, dps=80, max_iters=6, tol=None, return_logw=False, workers=1):
    """
    Compute high-precision Gauss–Laguerre nodes and weights for the weight
    x^alpha e^{-x} on [0,∞).

    A float64 tridiagonal eigensolve provides seed eigenpairs; each pair is then
    refined by at most `max_iters` Rayleigh-quotient iterations in mpmath. The
    Jacobi matrix used for the refinement is built directly in mpmath so its
    entries are accurate to `dps` digits rather than to float64.

    Args:
        K:           number of quadrature nodes.
        alpha:       Laguerre parameter, must be > -1.
        dps:         mpmath decimal precision. NOTE: `mp.mp.dps` is set to this
                     value and left that way on return.
        max_iters:   maximum refinement iterations per eigenpair.
        tol:         stopping tolerance; defaults to 10^-max(30, dps-8).
        return_logw: return log-weights instead of weights.
        workers:     number of worker processes; 1 (or less) runs in-process.

    Returns:
        (nodes, weights): two lists of mpf, sorted by increasing node.

    Results are cached on disk as decimal strings with dps+5 significant
    digits, so a cache hit returns the same precision as a fresh computation.
    Caching is best-effort: an unusable cache location or an unreadable cache
    file only disables the cache.
    """
    import zipfile  # only needed to recognise a corrupt cache archive

    assert alpha > -1, "alpha must be > -1"

    mp.mp.dps = dps
    if tol is None:
        tol = mp.mpf('1e-{}'.format(max(30, dps-8)))

    # ---- locate the cache entry (before doing any numerical work) ----
    cache_path = None
    try:
        cache_dir = Path(__file__).resolve().parents[2] / "data" / "gauss_laguerre_cache"
        cache_dir.mkdir(parents=True, exist_ok=True)
        tol_key = f"{float(tol):.6e}"
        # repr() keeps every digit of alpha, so nearby alphas never share a key
        alpha_key = repr(float(alpha))
        # The "_mpstr" suffix keeps these entries apart from older float64 files
        cache_path = cache_dir / (
            f"K{K}_alpha{alpha_key}_dps{dps}_maxit{max_iters}_tol{tol_key}"
            f"_log{int(return_logw)}_mpstr.npz"
        )
    except (IndexError, OSError):
        cache_path = None

    if cache_path is not None and cache_path.exists():
        try:
            with np.load(cache_path) as loaded:
                xs = [mp.mpf(str(v)) for v in loaded["xs"]]
                ws = [mp.mpf(str(v)) for v in loaded["ws"]]
            if len(xs) == K and len(ws) == K:
                return xs, ws
        except (OSError, ValueError, KeyError, EOFError, zipfile.BadZipFile):
            pass  # unreadable cache entry: recompute below

    # ---- seed in float64 ----
    _, _, x0, V0 = gauss_laguerre_jacobi_solver(K, alpha)

    # ---- Jacobi matrix in mp (not promoted from the float64-rounded one) ----
    alpha_mp = mp.mpf(alpha)
    dmp = [2*k + 1 + alpha_mp for k in range(K)]
    emp = [mp.sqrt(k * (k + alpha_mp)) for k in range(1, K)]
    mu0 = mp.gamma(alpha_mp + 1)

    args_list = [(i, x0, V0, dmp, emp, mu0, max_iters, tol, return_logw) for i in range(K)]
    if workers is not None and workers <= 1:
        # No parallelism requested: avoid the cost of spawning a pool
        results = [_process_single_eigenpair(task) for task in args_list]
    else:
        with Pool(processes=workers, initializer=_set_mp_worker_dps, initargs=(dps,)) as pool:
            results = pool.map(_process_single_eigenpair, args_list)

    xs = [None]*K
    ws = [None]*K
    for i, lam, w in results:
        xs[i] = lam
        ws[i] = w

    # sort by nodes and reorder weights accordingly
    order = sorted(range(K), key=lambda i: xs[i])
    x_sorted = [xs[i] for i in order]
    w_sorted = [ws[i] for i in order]

    # cache to disk as decimal strings (lossless at the working precision)
    if cache_path is not None:
        digits = dps + 5
        try:
            np.savez_compressed(
                cache_path,
                xs=np.array([mp.nstr(v, digits) for v in x_sorted]),
                ws=np.array([mp.nstr(v, digits) for v in w_sorted]),
            )
        except OSError:
            pass  # caching is optional

    return x_sorted, w_sorted

def check_gauss_laguerre(x, w, *, alpha=0.0, kmax=6, dps=None, w_is_log=False):
    """
    Check the accuracy of Gauss–Laguerre nodes and weights.

    The rule passes when the nodes are finite and strictly increasing, the
    weights are finite (skipped for log-weights), and the moments
    sum_i w_i x_i^k match Gamma(k + alpha + 1) to a relative 1e-30 for
    k = 0..min(kmax, 2n-1), where n is the number of nodes. An n-point Gauss
    rule is only exact up to degree 2n-1, so higher moments are not tested.

    Args:
        x:        nodes.
        w:        weights, or log-weights when `w_is_log` is true.
        alpha:    Laguerre parameter of the weight x^alpha e^{-x}.
        kmax:     highest moment order to test.
        dps:      if given, `mp.mp.dps` is set to this value (and left so).
        w_is_log: interpret `w` as log-weights.

    Returns:
        True when every check passes, else False. Empty input, or node and
        weight sequences of different length, return False.
    """
    if dps is not None:
        mp.mp.dps = int(dps)

    x = [mp.mpf(xi) for xi in x]
    vals_w = [mp.mpf(wi) for wi in w]
    n = len(x)
    if n == 0 or len(vals_w) != n:
        return False

    # finite & increasing
    if not all(mp.isfinite(v) for v in x):
        return False
    if not w_is_log and not all(mp.isfinite(v) for v in vals_w):
        return False
    if not all(x[i+1] > x[i] for i in range(n-1)):
        return False
    # log(x) below needs positive nodes (nodes are increasing, so test the first)
    if w_is_log and x[0] <= 0:
        return False

    def logsumexp(vals):
        m = max(vals)
        return m + mp.log(mp.fsum(mp.e**(v-m) for v in vals))

    def sum_w_xk(k):
        if w_is_log:
            return mp.e**logsumexp([lw + k*mp.log(xi) for lw, xi in zip(vals_w, x)])
        return mp.fsum(wi*(xi**k) for wi, xi in zip(vals_w, x))

    rel_tol = mp.mpf('1e-30')
    # k = 0 is the weight sum; stop at the exactness degree of the rule
    for k in range(min(int(kmax), 2*n - 1) + 1):
        rhs = mp.gamma(k+alpha+1)
        if mp.fabs(sum_w_xk(k) - rhs) > rel_tol*max(1, mp.fabs(rhs)):
            return False

    return True


# Discretization of the privacy loss random variable (PLRV) for SGG under P
class Discretize_SGG_PLRV:
    """
    CDF evaluator for the privacy loss g(R, W) of the SGG mechanism under P.

    The CDF is computed as sum_k wk[k] * S_W(w*_k(y)), where (rk, wk) are
    quadrature nodes/weights for the radial variable, w*_k(y) solves
    g(rk[k], w) = y by linear interpolation in a precomputed table, and S_W is
    the survival function of W.

    Attributes set by __init__ (all visible in the code below):
        T, alpha, beta, p, s  parameters (alpha/beta/p/s stored as mpf)
        rk, wk                radial nodes and normalized weights (mpf lists)
        wk_np                 float64 copy of wk
        w_tables              float64 grid of w from 1 down to -1
        w_tables_mp           the same grid as mpf
        gr_tables_mp          per-node increasing tables of g(r, w) (mpf)
        gr_tables             float64 copies of gr_tables_mp
    """

    def __init__(self, T, alpha, beta, p, s, K=128, precision=50, w_grid_size=2049, workers=50):
        self.T = int(T)
        self.alpha = mp.mpf(alpha)
        self.beta  = mp.mpf(beta)
        self.p     = mp.mpf(p)
        self.s     = mp.mpf(s)
        self.precision = precision
        self.workers = workers
        self._pool = None
        self._pool_workers = workers
        self.use_persistent_pool = False
        mp.mp.dps = precision

        if self.T <= 10:
            self.rk, self.wk = self._build_radial_nodes_two_panel(K=K, tau=None, ratio=0.75, alpha_tail=0.0)
        else:
            # Compute (customized) Gauss–Laguerre nodes/weights
            kappa = (self.alpha + 1) / self.p
            alpha_GL = kappa - 1
            self.tk, self.ak = gauss_laguerre_mixed_mp(K=K, alpha=float(alpha_GL), dps=max(80, precision), workers=workers)
            mp.mp.dps = precision

            # rk/wk: mapped node/weight for radial random variable R
            Gk = mp.gamma(kappa)
            self.rk = [ (tk/self.beta)**(1/self.p) for tk in self.tk ]
            self.wk = [ ak / Gk for ak in self.ak ]

            S = mp.fsum(self.wk)
            self.wk = [wk/S for wk in self.wk]
        # float64 copy for fast dot products
        self.wk_np = np.array([float(w) for w in self.wk], dtype=np.float64)

        # Precompute function table g_r(w)
        self.w_grid_size = w_grid_size
        self.c1 = (self.alpha + 1 - self.T) / 2
        self.w_tables = np.linspace(-1.0, 1.0, w_grid_size)[::-1]
        self.w_tables_mp = [mp.mpf(w) for w in self.w_tables]
        self.mp_neg_one = mp.mpf(-1)
        self.mp_one = mp.mpf(1)
        self.mp_zero = mp.mpf(0)

        self.gr_tables_mp  = []
        self.gr_tables = []

        # Compute g(r,w) table for different r and w
        args_list = [(r, self.w_tables_mp, self.func_g, self._enforce_strict_increasing)
                     for r in self.rk]
        if workers == 1:
            # Single worker requested: compute in-process, no pool needed
            results = [_process_single_r_value(task) for task in args_list]
        else:
            # The initializer hands the precision to workers that do not
            # inherit it from the parent (spawn / forkserver start methods).
            with Pool(processes=workers,
                      initializer=Discretize_SGG_PLRV._set_worker_precision,
                      initargs=(precision,)) as pool:
                results = pool.map(_process_single_r_value, args_list)

        for gs in results:
            self.gr_tables_mp.append(gs)
            self.gr_tables.append(np.array([float(v) for v in gs], dtype=np.float64))

    @staticmethod
    def _set_worker_precision(dps):
        """Pool initializer: set the mpmath precision inside a worker process."""
        mp.mp.dps = dps

    def __getstate__(self):
        """
        Pickle support: drop the persistent pool handle.

        The instance is pickled when it (or one of its bound methods) is sent
        to a worker process; a live Pool object cannot be pickled.
        """
        state = self.__dict__.copy()
        state['_pool'] = None
        return state

    def func_g(self, r, w):
        """Evaluate the privacy loss g(r, w) in mpmath arithmetic."""
        # g(r, w) = c1*log1p(2 s w/r + s^2/r^2) + beta*(r^p - (r^2+2swr+s^2)^{p/2})
        s, p, beta, c1 = self.s, self.p, self.beta, self.c1
        ln_arg = 2*s*w/r + (s**2)/(r**2)
        # avoid hitting log1p(-1) exactly
        if ln_arg <= -1:
            ln_arg = -1 + mp.mpf('1e-40')
        # Rounding can make the radicand slightly negative (r ≈ s, w = -1);
        # mp.sqrt would then return a complex number, so clamp to zero.
        radicand = r*r + 2*s*w*r + s*s
        D = mp.sqrt(radicand) if radicand > 0 else mp.mpf(0)
        return c1*mp.log1p(ln_arg) + beta*((r**p) - (D**p))

    def _enforce_strict_increasing(self, gs):
        """
        Return a strictly increasing copy of `gs`.

        Non-finite entries are replaced by the previous value (or 0 at the
        start); any entry not above its predecessor is lifted by 1e-30.
        """
        out, cur = [], -mp.inf
        tiny = mp.mpf('1e-30')
        for v in gs:
            if not mp.isfinite(v):
                v = cur if mp.isfinite(cur) else mp.mpf(0)
            if v <= cur:
                v = cur + tiny
            cur = v
            out.append(v)
        return out

    def _w_star_from_table(self, k, y):
        """
            Find w such that g(rk[k], w) = y by linear interpolation in the
            mpf table; clamps to the grid ends outside the table range.
        """
        idx  = bisect_left(self.gr_tables[k], float(y))
        xr   = self.gr_tables_mp[k]
        if idx <= 0:  return self.w_tables_mp[0]
        if idx >= len(self.w_tables_mp): return self.w_tables_mp[-1]
        xL, xR = xr[idx-1], xr[idx]
        wL, wR = self.w_tables_mp[idx-1], self.w_tables_mp[idx]
        if xR == xL:
            return wL
        t = (mp.mpf(y) - xL) / (xR - xL)
        return wL + t * (wR - wL)

    def _w_star_from_table_np(self, k, y):
        """
        Float64 version of _w_star_from_table for vectorized paths.
        """
        xr = self.gr_tables[k]  # float64 table, increasing
        if y <= xr[0]: return self.w_tables[0]
        if y >= xr[-1]: return self.w_tables[-1]
        idx = np.searchsorted(xr, y, side='left')
        xL, xR = xr[idx-1], xr[idx]
        wL, wR = self.w_tables[idx-1], self.w_tables[idx]
        if xR == xL:
            return wL
        t = (y - xL) / (xR - xL)
        return wL + t * (wR - wL)

    def _S_W(self, z):
        """
            Compute the survival function S_W(z) = Pr[W >= z] as the
            regularized incomplete beta function with both shape parameters
            (T-1)/2, evaluated at 1 - (z+1)/2. Falls back to 1 - (z+1)/2 when
            that evaluation is not finite.
        """
        if z <= self.mp_neg_one: return mp.mpf(1)
        if z >= self.mp_one: return mp.mpf(0)

        u = (z + 1) / 2
        if u <= 0: return mp.mpf(1)
        if u >= 1: return mp.mpf(0)
        a = (self.T - 1)/2.0
        val_f64 = sp_betainc(float(a), float(a), 1 - float(u))
        if not np.isfinite(val_f64):
            return mp.mpf(1) - mp.mpf(u)
        return mp.mpf(val_f64)

    def _S_W_vec(self, z_arr):
        """
        Vectorized survival function on float64 inputs. Returns float64 array.

        Mirrors _S_W, including the 1 - u fallback where the incomplete beta
        evaluation is not finite. NaN inputs yield NaN.
        """
        z = np.asarray(z_arr, dtype=np.float64)
        u = (z + 1.0) / 2.0
        a = (float(self.T) - 1.0) / 2.0

        # Start from NaN so entries matched by no branch are well defined
        out = np.full_like(u, np.nan, dtype=np.float64)
        out[u <= 0.0] = 1.0
        out[u >= 1.0] = 0.0
        mask = (u > 0.0) & (u < 1.0)
        if np.any(mask):
            u_in = u[mask]
            vals = sp_betainc(a, a, 1.0 - u_in)
            bad = ~np.isfinite(vals)
            if np.any(bad):
                vals = np.where(bad, 1.0 - u_in, vals)
            out[mask] = vals
        return out

    def _build_radial_nodes_two_panel(self, *, K, tau=None, ratio=0.75, alpha_tail=0.0):
        """
        Two-panel Gamma(kappa,1) quadrature for t in [0, ∞):
        Panel 0: [0, tau]     via Gauss–Legendre on x∈[0,1], t = tau*x
        Panel 1: [tau, ∞)     via Gauss–Laguerre on u∈[0,∞),  t = tau + u

        Returns:
            rk: list[mp.mpf]   radii r = (t/beta)^(1/p)
            wk: list[mp.mpf]   weights that already include Gamma normalization (sum≈1)
        """
        kappa = (self.alpha + 1) / self.p
        Gk    = mp.gamma(kappa)

        # Split point and allocation
        tau   = mp.mpf(max(1.0, float(kappa))) if tau is None else mp.mpf(tau)
        K0    = max(16, int(ratio * K))   # near-zero panel nodes
        K1    = max(16, K - K0)           # tail panel nodes

        # ---- Panel 0: [0, tau] with Gauss–Legendre on x∈[0,1], t = tau*x ----
        # leggauss gives float64 nodes u∈[-1,1] and weights; map to x∈[0,1]
        u_gl, w_gl = leggauss(K0)
        x0 = [ (mp.mpf(ui) + 1) / 2 for ui in u_gl ]     # mpf in [0,1]
        w0 = [  mp.mpf(wi) / 2       for wi in w_gl ]    # mpf, integrates on [0,1]

        t0  = [ tau * xi for xi in x0 ]                  # t ∈ [0, tau]
        # Panel-0 weights from: (tau^kappa/Gamma(kappa)) ∫ x^{kappa-1} e^{-tau x} F(tau x) dx
        wk0 = [ (tau * wi) * (ti**(kappa-1)) * mp.e**(-ti) / Gk
                for wi, ti in zip(w0, t0) ]
        rk0 = [ (ti / self.beta)**(1/self.p) for ti in t0 ]

        # ---- Panel 1: [tau, ∞) with Gauss–Laguerre on u∈[0,∞), t = tau + u ----
        # Laguerre nodes/weights for the weight u^{alpha_tail} e^{-u}
        u1, a1 = gauss_laguerre_mixed_mp(K=K1, alpha=float(alpha_tail),dps=max(80, self.precision),
                                        workers=self.workers)
        mp.mp.dps = self.precision

        t1  = [ tau + ui for ui in u1 ]                  # t ∈ [tau, ∞)
        # From: (e^{-tau}/Gamma(kappa)) ∫ e^{-u} (tau+u)^{kappa-1} F(tau+u) du
        # with general alpha_tail, Laguerre integrates u^{alpha_tail} e^{-u} g(u):
        # wk1 = a1 * u^{-alpha_tail} * e^{-tau} * (tau+u)^{kappa-1} / Gk
        wk1 = [ (mp.e**(-tau) * ai) * (ui**(-alpha_tail)) * (ti**(kappa-1)) / Gk
                for ai, ui, ti in zip(a1, u1, t1) ]
        rk1 = [ (ti / self.beta)**(1/self.p) for ti in t1 ]

        # ---- Merge and renormalize (tiny drift guard) ----
        rk = rk0 + rk1
        wk = wk0 + wk1
        S  = mp.fsum(wk)
        if S != 0:
            wk = [w/S for w in wk]
        return rk, wk


    # -------- public API --------

    def cdf_vec(self, y_array, workers=None, use_persistent_pool=None):
        """
        Compute CDF values for an array of y values.

        Args:
            y_array: Array of y values to compute CDF for
            workers: Number of worker processes (None uses initialization default)
            use_persistent_pool: If True, reuse a pool tied to this discretizer across calls
                                 to avoid process spin-up cost. Defaults to self.use_persistent_pool.

        Returns:
            A list with one value per input, clipped to [0, 1]. The sequential
            path (at most one input, or workers == 1) works in mpmath and
            returns mpf values; the parallel path works in float64 and returns
            Python floats.
        """
        if workers is None:
            workers = getattr(self, 'workers', 1)
        if use_persistent_pool is None:
            use_persistent_pool = getattr(self, 'use_persistent_pool', False)

        # For small arrays or single worker, use sequential processing
        if len(y_array) <= 1 or workers == 1:
            mp.mp.dps = self.precision
            out = []
            for y in y_array:
                y_mp = mp.mpf(y)
                acc = mp.mpf('0')
                for k in range(len(self.rk)):
                    w_star = self._w_star_from_table(k, y_mp)
                    acc += self.wk[k] * self._S_W(w_star)
                # clip tiny overshoots
                if acc < 0: acc = self.mp_zero
                if acc > 1: acc = self.mp_one
                out.append(acc)
            return out

        # Chunking approach: use many small chunks to reduce stragglers
        total = len(y_array)
        target_chunks = max(workers * 8, 32)
        chunk_size = max(32, min(1024, (total + target_chunks - 1) // target_chunks))
        chunks = [y_array[i:i + chunk_size] for i in range(0, total, chunk_size)]

        # Share the discretizer via initializer to avoid per-task pickling
        if use_persistent_pool:
            pool = self._ensure_pool(workers)
            chunk_results = pool.map(_process_y_chunk_cdf, chunks, chunksize=1)
        else:
            with Pool(processes=workers, initializer=_init_worker, initargs=(self,)) as pool:
                # Stream small chunks to improve load balancing
                chunk_results = pool.map(_process_y_chunk_cdf, chunks, chunksize=1)

        # Flatten the results back to original order
        out = []
        for chunk_result in chunk_results:
            out.extend(chunk_result)

        return out

    def cdf(self, y):
        """Return the CDF at a single point y (see cdf_vec)."""
        return self.cdf_vec([y])[0]

    def enable_persistent_pool(self, workers=None):
        """
        Create (or reuse) a multiprocessing pool tied to this discretizer so repeated
        cdf_vec calls avoid process spin-up cost. Remember to disable it when done.
        """
        if workers is not None:
            self.workers = workers
        self.use_persistent_pool = True
        self._ensure_pool(self.workers)

    def disable_persistent_pool(self):
        """Close and join the persistent pool, if any."""
        self.use_persistent_pool = False
        self._close_pool()

    def _ensure_pool(self, workers=None):
        """Return the persistent pool, (re)creating it if the worker count changed."""
        if workers is None:
            workers = getattr(self, 'workers', 1)
        if self._pool is None or workers != self._pool_workers:
            self._close_pool()
            self._pool_workers = workers
            self._pool = Pool(processes=workers, initializer=_init_worker, initargs=(self,))
        return self._pool

    def _close_pool(self):
        """Close and join the persistent pool if one exists."""
        if self._pool is not None:
            self._pool.close()
            self._pool.join()
            self._pool = None

def _discretize_spherical_generalized_gamma_to_grid(T, alpha, beta, p, s, L, h, M, prec=50, workers=1, K=1024):
    """
    Discretize the privacy-loss random variable (as modelled by
    Discretize_SGG_PLRV) to a pmf on grid points x_i = -L + i*h (i=0..M-1),
    assigning each bin the probability mass over [x_i - h/2, x_i + h/2],
    truncated to [-L, L], then normalized.

    Args:
        T, alpha, beta, p, s: parameters forwarded to Discretize_SGG_PLRV.
        L: truncation half-width (> 0).
        h: grid step (> 0).
        M: number of grid points (odd, > 0).
        prec, workers, K: precision, worker count and node count forwarded to
            the discretizer.

    Returns:
        (xs, pmf): grid points and the normalized probability mass per bin.

    Raises:
        AssertionError: on invalid M/h/L, or when no mass falls inside [-L, L].
    """
    # Ensure an odd number of points so that x=0 is exactly the middle bin.
    M = int(M)
    assert M % 2 == 1 and M > 0, "M must be odd and positive"
    assert h > 0, "h must be positive"
    assert L > 0, "L must be positive"

    xs = -L + h * np.arange(M)

    # Bin edges (left/right), clamped to [-L, L]
    left_edges  = np.maximum(xs - 0.5 * h, -L)
    right_edges = np.minimum(xs + 0.5 * h,  L)

    t_start = time.time()
    discretizer = Discretize_SGG_PLRV(
        T=T, alpha=alpha, beta=beta, p=p, s=s,
        K=K, precision=prec, w_grid_size=2049, workers=workers
    )
    t_end = time.time()
    print(f"[timing] Discretize_SGG_PLRV init took {t_end - t_start:.3f}s")

    # Compute CDF values using the discretizer
    t_start = time.time()
    edge_len = len(left_edges)
    # Deduplicate edges to avoid redundant CDF work when many edges repeat
    all_edges = np.concatenate([left_edges, right_edges])
    unique_edges, inverse_idx = np.unique(all_edges, return_inverse=True)

    try:
        cdf_unique = discretizer.cdf_vec(unique_edges.tolist(), use_persistent_pool=True)
    finally:
        # The discretizer is local to this call: always shut its worker pool
        # down, otherwise the worker processes outlive the function.
        discretizer.disable_persistent_pool()
    cdf_unique = np.array(cdf_unique, dtype=np.float64)
    cdf_edges = cdf_unique[inverse_idx]

    cdf_left = cdf_edges[:edge_len]
    cdf_right = cdf_edges[edge_len:]
    t_end = time.time()
    print(f"[timing] Discretize_SGG_PLRV cdf_vec took {t_end - t_start:.3f}s")

    pmf = (cdf_right - cdf_left).astype(np.float64)

    total = pmf.sum()
    assert total > 0, "Grid too small: all mass truncated. Increase L or h."
    pmf /= total  # renormalize after truncation
    return xs, pmf

class Discretize_SGG_PLRV_Q(Discretize_SGG_PLRV):
    """
    Q-side variant of Discretize_SGG_PLRV: the privacy-loss RV under Q (the
    shifted distribution), evaluated with the same quadrature over X~P.

    Everything is inherited from the parent class; only the privacy-loss
    function differs, using D_- = ||X - μ|| in place of D_+ = ||X + μ||.
    """
    def __init__(self, T, alpha, beta, p, s, K=128, precision=50, w_grid_size=2049, workers=50):
        super().__init__(
            T, alpha, beta, p, s,
            K=K, precision=precision, w_grid_size=w_grid_size, workers=workers,
        )

    def func_g(self, r, w):
        """Evaluate the Q-side privacy loss g_Q(r, w) in mpmath arithmetic."""
        # g_Q(r,w) = -c1*log1p(-2 s w/r + s^2/r^2) + beta*((r^2 - 2 s w r + s^2)^{p/2} - r^p)
        s = self.s
        p = self.p
        beta = self.beta
        c1 = self.c1

        # Keep the log1p argument strictly above -1
        ln_arg = -2 * s * w / r + (s**2) / (r**2)
        if ln_arg <= -1:
            ln_arg = -1 + mp.mpf("1e-40")

        # Squared distance; a non-positive value is treated as distance zero
        Dm2 = r*r - 2*s*w*r + s*s
        Dm = mp.sqrt(Dm2) if Dm2 > 0 else mp.mpf("0")

        log_part = (-c1) * mp.log1p(ln_arg)
        power_part = beta * (mp.power(Dm, p) - mp.power(r, p))
        return log_part + power_part

    
def _discretize_spherical_generalized_gamma_to_grid_under_Q(T, alpha, beta, p, s, L, h, M, prec=50, workers=1, K=1024):
    """
    Discretize the privacy-loss RV under Q on grid points x_i = -L + i*h (i=0..M-1),
    using the Q-side PLRV discretizer (Discretize_SGG_PLRV_Q).

    Each bin receives the probability mass over [x_i - h/2, x_i + h/2],
    truncated to [-L, L]; the result is then normalized.

    Args:
        T, alpha, beta, p, s: parameters forwarded to Discretize_SGG_PLRV_Q.
        L: truncation half-width (> 0).
        h: grid step (> 0).
        M: number of grid points (odd, > 0).
        prec, workers, K: precision, worker count and node count forwarded to
            the discretizer.

    Returns:
        (xs, pmf): grid points and the normalized probability mass per bin.

    Raises:
        AssertionError: on invalid M/h/L, or when no mass falls inside [-L, L].
    """
    M = int(M)
    assert M % 2 == 1 and M > 0, "M must be odd and positive"
    assert h > 0, "h must be positive"
    assert L > 0, "L must be positive"

    xs = -L + h * np.arange(M)

    # Bin edges (left/right), clamped to [-L, L]
    left_edges  = np.maximum(xs - 0.5 * h, -L)
    right_edges = np.minimum(xs + 0.5 * h,  L)

    t_start = time.time()
    discretizer = Discretize_SGG_PLRV_Q(
        T=T, alpha=alpha, beta=beta, p=p, s=s,
        K=K, precision=prec, w_grid_size=2049, workers=workers
    )
    t_end = time.time()
    print(f"[timing] Discretize_SGG_PLRV_Q init took {t_end - t_start:.3f}s")

    t_start = time.time()
    edge_len = len(left_edges)

    # Deduplicate edges: neighbouring bins share an edge
    all_edges = np.concatenate([left_edges, right_edges])
    unique_edges, inverse_idx = np.unique(all_edges, return_inverse=True)

    try:
        cdf_unique = discretizer.cdf_vec(unique_edges.tolist(), use_persistent_pool=True)
    finally:
        # The discretizer is local to this call: always shut its worker pool
        # down, otherwise the worker processes outlive the function.
        discretizer.disable_persistent_pool()
    cdf_unique = np.array(cdf_unique, dtype=np.float64)
    cdf_edges = cdf_unique[inverse_idx]

    cdf_left  = cdf_edges[:edge_len]
    cdf_right = cdf_edges[edge_len:]
    t_end = time.time()
    print(f"[timing] Discretize_SGG_PLRV_Q cdf_vec took {t_end - t_start:.3f}s")

    pmf = (cdf_right - cdf_left).astype(np.float64)
    ssum = pmf.sum()
    assert ssum > 0, "Grid too small: all mass truncated. Increase L or h."
    pmf /= ssum  # renormalize after truncation
    return xs, pmf


def _z_from_tail_mass(tail_mass):
    """
    Convert one-sided tail mass p to a z-score such that P[Z > z] = p for Z~N(0,1).
    Clamps extremely small/large inputs for stability.
    """
    tail_mass = float(tail_mass)
    tail_mass = min(max(tail_mass, 1e-20), 0.5)  # avoid under/overflow
    return float(np.sqrt(2.0) * erfcinv(2.0 * tail_mass))


def delta_from_single_prv(xs, pmf, eps):
    """
    Return sum_i pmf[i] * max(1 - exp(eps + xs[i]), 0) as a Python float.

    xs:  grid values (bin centers)
    pmf: probability mass on each grid value (summing to ~1)
    eps: the epsilon at which the sum is evaluated
    """
    grid = np.asarray(xs, dtype=np.float64)
    mass = np.asarray(pmf, dtype=np.float64)

    # Positive part of 1 - exp(eps + x)
    integrand = np.maximum(1.0 - np.exp(eps + grid), 0.0)

    return float(np.dot(mass, integrand))

def sgg_delta_via_fft_accounting(eps, T, alpha, beta, p, s, k, h=None, h_max=None, L=None, target_M=1<<16, prec=50, workers=1, tail_mass=1e-18, K=1024):
    """
    Compute delta at `eps` for the k-fold composition of the SGG mechanism by
    discretizing the privacy-loss random variable and composing it via FFT.

    Args:
        eps:       epsilon at which delta is evaluated.
        T, alpha, beta, p, s: mechanism parameters forwarded to the discretizer.
        k:         number of compositions.
        h:         grid step; derived from `target_M` (capped by `h_max`) when
                   None. It is re-derived from the final M in any case so the
                   grid spans exactly [-L, L].
        h_max:     optional upper bound on the derived step.
        L:         truncation half-width; derived from the first two moments
                   and the z-score of `tail_mass` when None.
        target_M:  target number of grid points when h is derived.
        prec, workers, K: forwarded to the discretizer.
        tail_mass: one-sided tail mass used to pick the z-score for L.

    Returns:
        delta clipped to [0, 1], as a float.
    """
    # First two moments of the generalized gamma distribution
    s = float(s)
    m = generalized_gamma_moments(alpha, beta, p, moment = 1, prec=prec)
    m2 = generalized_gamma_moments(alpha, beta, p, moment = 2, prec=prec)
    var = max(m2 - m**2, 0.0)
    std = np.sqrt(var)

    # Set parameters: L, the truncation half-width; h, the grid step; M, the number of grid points
    assert L is None or L > 0, "L must be positive."
    if L is None:
        # z-score matching the requested one-sided tail mass
        z = _z_from_tail_mass(tail_mass)
        mean_mag = abs(k * m)
        sd_sum = np.sqrt(k) * std
        L = max(abs(eps) + z * sd_sum, mean_mag + z * sd_sum) + 1.0

    if h is None:
        h_base = (2.0 * L) / (target_M - 1)
        if h_max is not None:
            h = min(h_base, h_max)
        else:
            h = h_base
        h = max(h, 1e-12)
    M = int(round(2 * L / h))
    if M % 2 == 0:
        M += 1
    # A coarse step (h >= 2L) would give M == 1 and divide by zero below;
    # three points {-L, 0, L} is the smallest grid that spans [-L, L].
    M = max(M, 3)
    # Recompute h from finalized M to ensure xs exactly span [-L, L]
    h = 2.0 * L / (M - 1)
    print(f"L: {L}, h: {h}, M: {M}")

    # Discretize the privacy loss random variable using Discretize_SGG_PLRV
    xs,  py = _discretize_spherical_generalized_gamma_to_grid(T, alpha, beta, p, s, L, h, M, prec=prec, workers=workers, K=K)

    # k-fold composition of the discretized privacy loss via FFT
    py_k = _pmf_power_fft(py, k, M)
    delta = delta_from_single_prv(xs, py_k, eps)

    return max(0.0, min(1.0, float(delta)))


    
