# pip install dp-accounting; here I use the version 0.5.0
import numpy as np
from scipy.signal import fftconvolve
from scipy.special import erf, erfc, log_ndtr, ndtr

def delta_Gaussian_mech_compositions_analytic(epsilon, sigma, k, s=1.0):
    """
    Compute delta(epsilon) of the k-fold composed Gaussian mechanism analytically.

    With noise std ``sigma`` and L2 sensitivity ``s``, the privacy loss of a single
    release is N(m, v) under the first neighbouring distribution and N(-m, v) under
    the second, where m = s^2 / (2 sigma^2) and v = s^2 / sigma^2. After k
    compositions these become N(k m, k v) and N(-k m, k v), and

        delta(eps) = P[N(k m, k v) > eps] - e^eps * P[N(-k m, k v) > eps],

    clipped below at 0.

    Parameters
    ----------
    epsilon : privacy parameter epsilon.
    sigma : standard deviation of the Gaussian noise.
    k : number of compositions.
    s : L2 sensitivity (default 1.0).

    Returns
    -------
    float : delta(epsilon), never negative.
    """
    m = (s**2) / (2.0 * sigma**2)
    v = (s**2) / (sigma**2)
    sd = np.sqrt(k * v)

    # P[N(k m, k v) > eps] = Phi((k m - eps) / sd)
    term_p = ndtr((k * m - epsilon) / sd)
    # e^eps * P[N(-k m, k v) > eps] is evaluated as exp(eps + log Phi(.)) so that a
    # huge e^eps times an underflowed tail cannot become inf * 0 = nan.
    log_term_q = log_ndtr(-(epsilon + k * m) / sd)
    delta = term_p - np.exp(epsilon + log_term_q)
    return float(np.maximum(0.0, delta))

def _normal_cdf(x, mean, std):
    z = (x - mean) / (std * np.sqrt(2.0))
    return 0.5 * (1.0 + erf(z))

def _discretize_normal_to_grid(mean, std, L, h, M):
    """
    Discretize a 1D Normal(mean, std^2) to a pmf on grid points x_i = -L + i*h (i=0..M-1), 
    assigning each bin the probability mass over [x_i - h/2, x_i + h/2], 
    truncated to [-L, L], then normalized.
    """
    # Ensure an odd number of points so that x=0 is exactly the middle bin.
    M = int(M)
    assert M % 2 == 1 and M > 0, "M must be odd and positive"
    assert h > 0, "h must be positive"
    assert L > 0, "L must be positive"

    xs = -L + h * np.arange(M)

    # Bin edges (left/right), clamped to [-L, L]
    left_edges  = np.maximum(xs - 0.5 * h, -L)
    right_edges = np.minimum(xs + 0.5 * h,  L)

    # Probability per bin via CDF differences
    cdf_left  = np.vectorize(_normal_cdf)(left_edges,  mean, std)
    cdf_right = np.vectorize(_normal_cdf)(right_edges, mean, std)
    pmf = (cdf_right - cdf_left).astype(np.float64)

    s = pmf.sum()
    assert s > 0, "Grid too small: all mass truncated. Increase L or h."
    pmf /= s  # renormalize after truncation
    return xs, pmf

# Linear convolution via FFT
def _fft_linear_convolve(a, b):
    """ Linear convolution via FFT """
    c = fftconvolve(a, b, mode="full")
    c = np.maximum(c, 0.0)
    c /= c.sum()

    return c

def _crop_center(c, M):
    """
    After convolving two M-length pmfs (on [-L, L]), we get length 2M-1,
    representing [-2L, 2L] with step h. Crop the center window back to [-L, L].
    """
    n = len(c)
    if n < M:
        raise ValueError("Cannot crop: result shorter than target length.")
    start = (n - M) // 2
    end = start + M
    r = c[start:end]
    s = r.sum()
    if s > 0:
        r /= s
    return r

def _pmf_power_fft(base_pmf, k, M):
    """
    Compose k i.i.d. copies via FFT using exponentiation by squaring + cropping.
    All pmfs are defined on the same fixed grid length M ([-L, L]).
    """
    # Dirac at 0 on the same grid: mass 1 at center bin.
    r = np.zeros_like(base_pmf)
    r[M // 2] = 1.0
    p = base_pmf.copy()

    kk = k
    while kk > 0:
        if kk & 1:
            r = _crop_center(_fft_linear_convolve(r, p), M)
        kk >>= 1
        if kk:
            p = _crop_center(_fft_linear_convolve(p, p), M)
    return r

def _tail_prob_from_pmf(xs, pmf, h, eps):
    """
    Compute Pr[Z > eps] from discretized pmf (bins centered at xs[i]) with bin width h.
    """
    # idx = np.searchsorted(xs, eps - 0.5*h, side='right')
    # cdf_eps = pmf[:idx].sum() if idx > 0 else 0.0
    idx = np.searchsorted(xs, eps, side='right')
    if idx == 0:
        return 1.0
    # full mass below bins strictly left of idx-1
    cdf_eps = pmf[:idx-1].sum()
    # partial mass from the bin containing eps, assuming uniform density within bin
    left_edge = xs[idx-1] - 0.5 * h
    frac = min(1.0, max(0.0, (eps - left_edge) / h))
    cdf_eps += pmf[idx-1] * frac
    return max(0.0, 1.0 - cdf_eps)

def gaussian_delta_via_fft_accounting(eps, sigma, k, s=1.0, h=None, L=None, target_M=1<<16, verbose=False):
    """
    Compute delta(eps) for k compositions of the Gaussian mechanism numerically.

    Uses FFT-based composition of Gaussian privacy-loss random variables (PRVs),
    following 'Numerical Composition of Differential Privacy' (Gopi, Lee,
    Wutschitz), and returns the smallest delta for a fixed epsilon.

    The single-release privacy loss is X ~ N(m, v) and Y ~ N(-m, v), with
    m = s^2 / (2 sigma^2) and v = s^2 / sigma^2. Both are discretized on a common
    grid over [-L, L] and composed k times by FFT convolution. They are then
    combined as delta = P[X_sum > eps] - e^eps * P[Y_sum > eps], clipped to [0, 1].

    Parameters
    ----------
    eps : privacy parameter epsilon to evaluate.
    sigma : Gaussian noise std (for N(0, sigma^2 I_d) vs N(mu, sigma^2 I_d)).
    k : number of i.i.d. compositions (independent releases).
    s : L2-norm of the mean shift ||mu||_2 (sensitivity). Defaults to 1.0.
    h : grid step. If None, chosen so the grid has about target_M points. In all
        cases the step is then adjusted slightly so that an odd number M >= 3 of
        points spans exactly [-L, L].
    L : truncation half-width (> 0). If None, it is set to
        max(|eps|, k*m) + 6 * sqrt(k * v) + 1.
    target_M : approximate number of grid points when h is None (need not be odd).
        Controls memory/time.
    verbose : if True, print the chosen L, h and M.

    Returns
    -------
    float : delta in [0, 1].
    """
    # Base PRV parameters for Gaussian mechanism
    s = float(s)
    m = (s**2) / (2.0 * sigma**2)      # mean of X; Y has mean -m
    v = (s**2) / (sigma**2)            # variance of both X and Y
    std = np.sqrt(v)

    # L: truncation half-width; h: grid step; M: number of grid points
    assert L is None or L > 0, "L must be positive."
    if L is None:
        sd_sum = np.sqrt(k) * std
        L = max(abs(eps) + 6.0 * sd_sum, abs(k * m) + 6.0 * sd_sum) + 1.0
    if h is None:
        h = max((2.0 * L) / (target_M - 1), 1e-8)

    # Keep at least 3 points so that the step 2L/(M-1) below is well defined even
    # when a very coarse h is requested.
    M = max(int(round(2 * L / h)), 3)
    if M % 2 == 0:
        M += 1
    # Recompute h from finalized M to ensure xs exactly span [-L, L]
    h = 2.0 * L / (M - 1)
    if verbose:
        print(f"L: {L}, h: {h}, M: {M}")

    # Discretize base PRVs X and Y
    xs, px = _discretize_normal_to_grid(+m, std, L, h, M)
    _, py = _discretize_normal_to_grid(-m, std, L, h, M)

    # k-fold composition via FFT with repeated squaring + cropping
    px_k = _pmf_power_fft(px, k, M)
    py_k = _pmf_power_fft(py, k, M)

    # Evaluate delta(eps) = P[X_sum > eps] - e^eps P[Y_sum > eps]
    tail_x = _tail_prob_from_pmf(xs, px_k, h, eps)
    tail_y = _tail_prob_from_pmf(xs, py_k, h, eps)
    # Skip the product when tail_y == 0: an overflowing e^eps would otherwise give
    # inf * 0 = nan, and min(1.0, nan) evaluates to 1.0, i.e. a spurious delta = 1.
    weighted_y = np.exp(eps) * tail_y if tail_y > 0 else 0.0
    delta = tail_x - weighted_y
    return max(0.0, min(1.0, float(delta)))