#!/usr/bin/env python3
"""
Verify the Gaussian tangent-test conditions in Lemma (gauss-exists-delta).

Given (eps, delta, s), define g(u) = delta_Gauss(u)(eps) for additive Gaussian noise
with per-coordinate variance u = sigma^2 and L2-sensitivity s.

Compute u0 such that g(u0)=delta, build the tangent line ell(u) at u0, and verify:
  (1) g is convex on [u0, ∞)  (numerically: g'(u) nondecreasing on [u0, u_max])
  (2) g(u) >= ell(u) for all u in [0, u0] (numerically on a grid + check u=0 limit)

Run:
  python verify_gauss_tangent.py --eps 1.0 --delta 1e-6 --s 1.0
"""

import argparse
import math
import numpy as np

from scipy.special import log_ndtr  # log Phi(x), numerically stable


# Normalizing constant sqrt(2*pi) of the standard normal density
# (math.tau is exactly 2*math.pi as a double, so this is bit-identical).
SQRT_2PI = math.sqrt(math.tau)


def _phi(x: np.ndarray) -> np.ndarray:
    """Evaluate the standard normal density exp(-x^2 / 2) / sqrt(2*pi) elementwise."""
    half_square = 0.5 * x * x
    return np.exp(-half_square) / SQRT_2PI


def g_gauss(u: np.ndarray, eps: float, s: float) -> np.ndarray:
    """
    Gaussian privacy function g(u) = δ_G(ε) for X~N(0,u I), Y=X+μ, ||μ||=s.
    Uses stable log-space computation:
      g(u) = Phi(a) - exp(eps)*Phi(b),
      a = -(eps*sqrt(u))/s + s/(2*sqrt(u)),
      b = -(eps*sqrt(u))/s - s/(2*sqrt(u)).

    Args:
        u: per-coordinate noise variance sigma^2 (scalar or array). Values
           below 1e-300 are clamped to 1e-300, so u = 0 is evaluated there.
        eps: privacy parameter epsilon.
        s: L2 sensitivity (assumed > 0).

    Returns:
        Array broadcast from u with values clipped to [0, 1]. Entries where
        both Phi terms underflow to zero in log space (far tail) are 0;
        NaN inputs propagate as NaN.
    """
    u = np.asarray(u, dtype=float)
    u = np.maximum(u, 1e-300)
    r = np.sqrt(u)

    a = -(eps * r) / s + s / (2.0 * r)
    b = -(eps * r) / s - s / (2.0 * r)

    logA = log_ndtr(a)          # log Phi(a)
    logB = log_ndtr(b) + eps    # log (exp(eps)*Phi(b))

    # Start from NaN so anything not handled below (NaN inputs) stays NaN.
    out = np.full_like(u, np.nan)

    # Both terms underflow (log_ndtr -> -inf far in the tail). Then
    # logB - logA = -inf - (-inf) = NaN, which np.clip would NOT remove.
    # The true value lies in [0, tiny], so report 0.
    both_underflow = np.isneginf(logA) & np.isneginf(logB)
    out[both_underflow] = 0.0

    # When logA >= logB: exp(logA)*(1 - exp(logB-logA)) = exp(logA)*(-expm1(logB-logA))
    a_dom = (logA >= logB) & ~both_underflow
    out[a_dom] = np.exp(logA[a_dom]) * (-np.expm1(logB[a_dom] - logA[a_dom]))

    # When logA < logB: exp(logB)*(exp(logA-logB) - 1) = exp(logB)*expm1(logA-logB)
    # (negative, i.e. rounding noise, since the true g is >= 0; clipped below).
    b_dom = logA < logB
    out[b_dom] = np.exp(logB[b_dom]) * np.expm1(logA[b_dom] - logB[b_dom])

    # Numerical guard: DP delta should be in [0,1]
    return np.clip(out, 0.0, 1.0)


def gprime_gauss(u: np.ndarray, eps: float, s: float) -> np.ndarray:
    """
    Derivative g'(u) of g(u) = Phi(a(u)) - exp(eps)*Phi(b(u)), a(u), b(u) as in g_gauss.

    Term-by-term differentiation gives
      g'(u) = phi(a)*a'(u) - exp(eps)*phi(b)*b'(u),
      a'(u) = -eps/(2 s r) - s/(4 r^3),  b'(u) = -eps/(2 s r) + s/(4 r^3),  r = sqrt(u).
    Because b^2 - a^2 = 2*eps, exp(eps)*phi(b) = phi(a) exactly, so the eps
    parts cancel analytically and
      g'(u) = phi(a) * (a' - b') = -s * phi(a) / (2 * u^{3/2}).
    Using this collapsed form (evaluated in log space) avoids:
      * catastrophic cancellation between two nearly equal terms in the tail,
      * OverflowError from math.exp(eps) for eps > ~709,
      * 0 * inf = NaN when u is so small that r**3 underflows.

    Args:
        u: variance(s) at which to evaluate (clamped below at 1e-300).
        eps: privacy parameter epsilon.
        s: L2 sensitivity (assumed > 0).

    Returns:
        g'(u), which is <= 0 (g is decreasing in u).
    """
    u = np.asarray(u, dtype=float)
    u = np.maximum(u, 1e-300)
    r = np.sqrt(u)

    a = -(eps * r) / s + s / (2.0 * r)

    # log(phi(a) / u^{3/2}); a*a may overflow to inf, giving exp(-inf) = 0,
    # which is the correct limit, so silence the overflow warning.
    with np.errstate(over="ignore"):
        log_mag = -0.5 * a * a - 0.5 * math.log(2.0 * math.pi) - 1.5 * np.log(u)
        return -0.5 * s * np.exp(log_mag)


def find_u0_for_delta(eps: float, delta: float, s: float,
                      u_lo: float = 1e-24, u_hi: float = 1.0,
                      max_expand: int = 200, tol: float = 1e-12) -> float:
    """
    Find u0 such that g(u0)=delta by bracketing + bisection.
    Assumes g is (strictly) decreasing in u.

    Args:
        eps: privacy parameter epsilon.
        delta: target value of g, must lie in (0, 1).
        s: L2 sensitivity; the initial bracket is scaled by s**2.
        u_lo, u_hi: initial bracket in units of s**2 (shrunk / expanded as needed).
        max_expand: maximum number of shrink or expand steps for the bracket.
        tol: RELATIVE tolerance on g; stop once |g(u) - delta| <= tol * delta.

    Returns:
        The variance u0 with g(u0) ~= delta.

    Raises:
        ValueError: if delta is not in (0, 1).
        RuntimeError: if no bracket with g(u_lo) >= delta >= g(u_hi) is found.
    """
    if not (0.0 < delta < 1.0):
        raise ValueError("delta must be in (0,1).")

    # Scale initial bracket by s^2 to be sensible across s
    u_lo = max(u_lo * (s**2), 1e-300)
    u_hi = max(u_hi * (s**2), u_lo * 10.0)

    g_lo = float(g_gauss(u_lo, eps, s))
    g_hi = float(g_gauss(u_hi, eps, s))

    # We need g(u_lo) >= delta >= g(u_hi).
    # If u_lo already too large, shrink it.
    shrink_iter = 0
    while g_lo < delta and shrink_iter < max_expand:
        u_lo *= 0.1
        u_lo = max(u_lo, 1e-300)
        g_lo = float(g_gauss(u_lo, eps, s))
        shrink_iter += 1

    # Expand u_hi until g(u_hi) <= delta
    expand_iter = 0
    while g_hi > delta and expand_iter < max_expand:
        u_hi *= 2.0
        g_hi = float(g_gauss(u_hi, eps, s))
        expand_iter += 1

    if not (g_lo >= delta >= g_hi):
        raise RuntimeError(
            f"Failed to bracket root: g(u_lo)={g_lo:.3e}, g(u_hi)={g_hi:.3e}, delta={delta:.3e}.\n"
            "Try adjusting u_lo/u_hi or max_expand."
        )

    # Bisection
    for _ in range(200):
        u_mid = 0.5 * (u_lo + u_hi)
        if u_mid <= u_lo or u_mid >= u_hi:
            # Bracket has collapsed to adjacent floats; no further progress possible.
            break
        g_mid = float(g_gauss(u_mid, eps, s))
        # Relative tolerance: delta is always < 1 and may be 1e-12 or smaller,
        # so an absolute tolerance would accept almost any u in the tail.
        if abs(g_mid - delta) <= tol * delta:
            return u_mid
        if g_mid > delta:
            u_lo = u_mid
        else:
            u_hi = u_mid

    return 0.5 * (u_lo + u_hi)


def verify_tangent_test(eps: float, delta: float, s: float,
                        n_left: int = 4000, n_right: int = 4000,
                        tol_support: float = 1e-12, tol_convex: float = 1e-10):
    """
    Verify conditions of Lemma gauss-exists-delta numerically.

    Steps:
      1. Solve g(u0) = delta and form the tangent ell(u) = g(u0) + g'(u0)(u - u0).
      2. Left support: check g(u) - ell(u) >= -tol_support on {0} plus a
         geometric grid on [u0*1e-18, u0], using the limit g(0) = 1.
      3. Tail convexity: check that consecutive differences of g' on a
         geometric grid over [u0, u_max] are >= -tol_convex.
         Here u_max is found by doubling from u0 until g(u_max) drops below a
         target that is strictly smaller than g(u0).

    Both tolerances are ABSOLUTE. For very small delta the compared quantities
    are themselves tiny, so callers may want to pass correspondingly smaller
    tolerances.

    Returns:
        Diagnostics dict with keys "ok", "eps", "delta", "s", "u0", "g(u0)",
        "g'(u0)", "u_max", "min_h_left", "min_h_right", "min_diff_gprime",
        "tol_support", "tol_convex". "ok" combines only the left-support and
        convexity checks; "min_h_right" is an informational sanity value.
    """
    u0 = find_u0_for_delta(eps, delta, s)
    g0 = float(g_gauss(u0, eps, s))
    gp0 = float(gprime_gauss(u0, eps, s))

    # Tangent line ell(u) = g(u0) + g'(u0)*(u-u0)
    def ell(u):
        return g0 + gp0 * (u - u0)

    # Left-side support check on [0, u0]
    u_min = max(u0 * 1e-18, 1e-300)
    u_left = np.concatenate([
        np.array([0.0]),
        np.geomspace(u_min, u0, num=n_left)
    ])
    g_left = g_gauss(np.maximum(u_left, 1e-300), eps, s)
    # Use the correct limit g(0)=1 (mechanism reveals exact answer)
    g_left[0] = 1.0
    h_left = g_left - ell(u_left)
    min_h_left = float(np.min(h_left))

    # Choose u_max for tail convexity check.
    # The target must lie strictly below g(u0). The previous floor of 1e-20
    # exceeded g(u0) for delta <~ 1e-20, leaving u_max == u0 and making the
    # convexity grid a single repeated point (a vacuous pass).
    target = min(1e-16, g0 * 1e-6)
    # Always look beyond u0 so the tail grid has positive width.
    u_max = 2.0 * u0
    for _ in range(200):
        if float(g_gauss(u_max, eps, s)) <= target:
            break
        u_max *= 2.0

    # Tail convexity check: g' should be nondecreasing on [u0, u_max]
    u_right = np.geomspace(u0, u_max, num=n_right)
    gp_right = gprime_gauss(u_right, eps, s)
    diffs = np.diff(gp_right)
    min_diff_gp = float(np.min(diffs))

    # Optional sanity: on a convex tail, g(u) lies above its tangent at u0 as well
    g_right = g_gauss(u_right, eps, s)
    h_right = g_right - ell(u_right)
    min_h_right = float(np.min(h_right))

    ok_left = (min_h_left >= -tol_support)
    ok_convex = (min_diff_gp >= -tol_convex)  # g' nondecreasing
    ok = ok_left and ok_convex

    return {
        "ok": ok,
        "eps": eps,
        "delta": delta,
        "s": s,
        "u0": u0,
        "g(u0)": g0,
        "g'(u0)": gp0,
        "u_max": u_max,
        "min_h_left": min_h_left,
        "min_h_right": min_h_right,
        "min_diff_gprime": min_diff_gp,
        "tol_support": tol_support,
        "tol_convex": tol_convex,
    }


def main():
    """Parse the command line, run verify_tangent_test, and print a text report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--eps", type=float, required=True)
    parser.add_argument("--delta", type=float, required=True)
    parser.add_argument("--s", type=float, default=1.0)
    # Grid sizes and tolerances mirror the defaults of verify_tangent_test.
    for flag, kind, default in (
        ("--n_left", int, 4000),
        ("--n_right", int, 4000),
        ("--tol_support", float, 1e-12),
        ("--tol_convex", float, 1e-10),
    ):
        parser.add_argument(flag, type=kind, default=default)
    opts = parser.parse_args()

    # Each parsed option's dest name is exactly a keyword of verify_tangent_test.
    diag = verify_tangent_test(**vars(opts))

    gp0 = diag["g'(u0)"]
    verdict = "PASS" if diag["ok"] else "FAIL"
    report = [
        "=== Gaussian tangent-test verification ===",
        f"eps={diag['eps']}, delta={diag['delta']}, s={diag['s']}",
        f"u0 (solve g(u0)=delta): {diag['u0']:.16e}",
        f"g(u0): {diag['g(u0)']:.16e}",
        f"g'(u0): {gp0:.16e}",
        f"u_max (tail checked to): {diag['u_max']:.16e}",
        "--- Left-side support (u in [0,u0]) ---",
        f"min_{{u∈[0,u0]}} [g(u)-ell(u)] = {diag['min_h_left']:.3e}  (tol={diag['tol_support']:.1e})",
        "--- Tail convexity (u in [u0,u_max]) ---",
        f"min diff of g'(u) over grid = {diag['min_diff_gprime']:.3e}  (tol={diag['tol_convex']:.1e})",
        f"sanity: min_{{u∈[u0,u_max]}} [g(u)-ell(u)] = {diag['min_h_right']:.3e}",
        f"RESULT: {verdict}",
    ]
    print("\n".join(report))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
