import numpy as np
from typing import List, Sequence, Tuple, Optional

# ---------------------- Privacy helpers ----------------------


def epsilon_schedule(
    total_eps: float, k: int, weights: Optional[Sequence[float]] = None
) -> np.ndarray:
    """
    Round-wise epsilon allocation.

    - Equal split if weights=None.
    - Otherwise proportional to weights (the result sums to total_eps).

    Args:
        total_eps: budget to distribute across the k rounds.
        k: number of rounds (must be >= 1).
        weights: optional 1-D sequence of k finite, non-negative weights,
            not all zero. Rounds with weight 0 receive epsilon 0.

    Returns:
        float array of shape (k,) with the per-round epsilons.

    Raises:
        ValueError: if k < 1, or if weights are not a finite, non-negative,
            1-D length-k sequence with a positive sum.
    """
    if k <= 0:
        raise ValueError("k must be >= 1")
    total_eps = float(total_eps)
    if weights is None:
        return np.full(k, total_eps / k, dtype=float)
    w = np.asarray(weights, dtype=float)
    # NaN compares False against 0, so `w < 0` alone would let NaN weights
    # through and silently yield a NaN schedule (inf weights -> inf/inf = NaN).
    if (
        w.ndim != 1
        or w.size != k
        or not np.all(np.isfinite(w))
        or np.any(w < 0)
        or w.sum() == 0
    ):
        raise ValueError(
            "weights must be finite, non-negative, 1-D length-k and not all zero"
        )
    return total_eps * (w / w.sum())


def advanced_composition_eps(
    eps_list: Sequence[float], delta_prime: float
) -> Tuple[float, float]:
    """
    Advanced composition (Dwork–Roth Thm 3.20 style) for k pure-DP mechanisms ε_i.

    Uses the heterogeneous form
        ε_total = sqrt(2 ln(1/δ') * Σ ε_i^2) + Σ ε_i * (e^{ε_i} - 1)
    which reduces to Thm 3.20 when all ε_i are equal.

    Note: for few rounds or large ε_i this bound can exceed the basic
    composition Σ ε_i; callers may want to report min(basic, advanced).

    Args:
        eps_list: per-mechanism epsilons (each >= 0).
        delta_prime: the δ' slack of the advanced bound, in (0, 1).

    Returns:
        (ε_total, δ_total) as plain Python floats, with δ_total = δ'
        (caller can add per-round δ if any).

    Raises:
        ValueError: if any ε_i is negative or NaN, or δ' is not in (0, 1)
            (NaN included).
    """
    eps = np.asarray(eps_list, dtype=float)
    delta_prime = float(delta_prime)
    # Written as `not (...)` so NaN, which fails every comparison, is rejected.
    if not np.all(eps >= 0) or not (0.0 < delta_prime < 1.0):
        raise ValueError("eps_list must be non-negative and 0<delta_prime<1")
    quad = np.sqrt(2.0 * np.log(1.0 / delta_prime) * np.sum(eps**2))
    # expm1 keeps precision for the small per-round epsilons typical here.
    lin = np.sum(eps * np.expm1(eps))
    return (float(quad + lin), delta_prime)


# ---------------------- Exponential Mechanism ----------------------


def exp_mech_pick_one(
    utilities: np.ndarray,
    epsilon: float,
    delta_u: float = 1.0,
    rng: Optional[np.random.Generator] = None,
) -> int:
    """
    Pick one index with probability ~ exp( (ε * u_i) / (2 Δu) ).
    Assumes utilities are already clipped so that sensitivity Δu holds.

    Args:
        utilities: 1-D array of scores. ``-inf`` entries are allowed and are
            never selected (probability 0); NaN and ``+inf`` are rejected, and
            at least one entry must be finite.
        epsilon: privacy budget for this single draw; must be finite and
            >= 0 (0 gives a uniform draw over the finite entries).
        delta_u: sensitivity bound of the utility; must be > 0.
        rng: NumPy Generator; a fresh ``default_rng()`` is used when None.

    Returns:
        Index into ``utilities`` of the selected item.

    Raises:
        ValueError: on invalid ``delta_u``, ``epsilon`` or ``utilities``.
    """
    # `not x > 0` also rejects NaN.
    if not delta_u > 0:
        raise ValueError("delta_u must be > 0")
    epsilon = float(epsilon)
    # A negative epsilon would invert the mechanism (favor low utilities).
    if not (np.isfinite(epsilon) and epsilon >= 0):
        raise ValueError("epsilon must be finite and >= 0")
    u = np.asarray(utilities, dtype=np.float64)
    if u.ndim != 1:
        raise ValueError("utilities must be 1-D")
    if u.size == 0:
        raise ValueError("utilities must be non-empty")
    if np.isnan(u).any() or not np.isfinite(u.max()):
        raise ValueError(
            "utilities must not contain NaN or +inf and need a finite entry"
        )
    if rng is None:
        rng = np.random.default_rng()

    # Stable softmax. After the checks above the only non-finite entries are
    # -inf, which get weight 0 directly (avoids 0 * -inf = NaN when eps == 0).
    # Shifting by the max makes the top weight exp(0) = 1, so w.sum() >= 1.
    scale = epsilon / (2.0 * delta_u)
    finite = np.isfinite(u)
    w = np.zeros_like(u)
    w[finite] = np.exp(scale * (u[finite] - u.max()))
    p = w / w.sum()
    return int(rng.choice(u.size, p=p))


# ---------------------- DP top-k via peeling ----------------------


def dp_topk_exp_mech(
    utilities: Sequence[float],
    k: int,
    total_epsilon: float,
    delta_u: float = 1.0,
    eps_weights: Optional[Sequence[float]] = None,
    clip_range: Optional[Tuple[float, float]] = None,
    return_ordered: bool = True,
    seed: Optional[int] = None,
    delta_prime: float = 1e-6,
) -> Tuple[np.ndarray, dict]:
    """
    Top-k selection by peeling with Exponential Mechanism.

    Runs k rounds; round t draws one not-yet-chosen index with
    ``exp_mech_pick_one`` using budget ``eps_round[t]`` and removes it from
    the candidate pool.
    - Good for composition: per-round eps schedule, returns accounting info.
    - If only the set matters, ignore order in evaluation.

    Args
    ----
    utilities: 1-D list/array of scores (float OK). Clip if needed to ensure Δu bound.
    k: number of items to select (1..d)
    total_epsilon: total privacy budget for this top-k selection (> 0)
    delta_u: sensitivity of the utility function (L1/L∞-style bound you assume)
    eps_weights: optional length-k weights for per-round ε allocation
    clip_range: (low, high) with low <= high, to clip utilities before
        selection (recommended)
    return_ordered: True -> indices in the order they were chosen;
        False -> the same indices sorted ascending (set semantics)
    seed: RNG seed
    delta_prime: δ' in (0, 1) used for the advanced-composition bound
        reported in ``info["eps_total_adv"]`` (default 1e-6)

    Returns
    -------
    (idx, info) where
      - idx: np.ndarray of chosen indices (length k)
      - info: {"eps_round": np.ndarray,
               "eps_total_basic": float (sum of per-round ε),
               "eps_total_adv": (ε, δ'),
               "delta_u": float, "k": int, "d": int}

    Raises
    ------
    ValueError: on non-1-D utilities, k outside [1, d], total_epsilon <= 0,
        delta_prime outside (0, 1), clip_range with low > high, or invalid
        eps_weights.
    """
    rng = np.random.default_rng(seed)
    u = np.asarray(utilities, dtype=np.float64)
    if u.ndim != 1:
        raise ValueError("utilities must be 1-D")
    d = u.size
    if not (1 <= k <= d):
        raise ValueError("k must be in [1, d]")
    # `not x > 0` also rejects NaN.
    if not total_epsilon > 0:
        raise ValueError("total_epsilon must be > 0")
    # Validate before any selection so a bad δ' does not fail after the draws.
    if not (0.0 < delta_prime < 1.0):
        raise ValueError("delta_prime must be in (0, 1)")
    if clip_range is not None:
        lo, hi = clip_range
        # np.clip with lo > hi silently maps everything to `hi`, which would
        # turn the selection into a uniform draw.
        if lo > hi:
            raise ValueError("clip_range must satisfy low <= high")
        u = np.clip(u, lo, hi)

    eps_round = epsilon_schedule(total_epsilon, k, eps_weights)
    remaining = list(range(d))
    chosen = []

    for eps_t in eps_round:
        # The mechanism sees only the still-available items; `remaining`
        # maps its local pick back to an index into the original array.
        pick_local = exp_mech_pick_one(u[remaining], eps_t, delta_u, rng)
        chosen.append(remaining.pop(pick_local))

    chosen = np.array(chosen, dtype=int)
    if not return_ordered:
        # just as a set (sorted indices)
        chosen = np.sort(chosen)

    # Accounting: basic sequential composition and the advanced bound at
    # the caller-chosen δ'.
    eps_basic = float(np.sum(eps_round))
    eps_adv, delta_adv = advanced_composition_eps(eps_round, delta_prime=delta_prime)

    info = {
        "eps_round": eps_round,
        "eps_total_basic": eps_basic,
        "eps_total_adv": (eps_adv, delta_adv),
        "delta_u": float(delta_u),
        "k": int(k),
        "d": int(d),
    }
    return chosen, info


# ---------------------- Example ----------------------
if __name__ == "__main__":
    # Float utilities (already clipped in [0,1] -> Δu <= 1 if single-record effect)
    scores = np.array([0.83, 0.91, 0.77, 0.52, 0.21, 0.19], dtype=np.float64)
    k = 3
    eps_total = 4.2
    idx, info = dp_topk_exp_mech(
        scores, k, eps_total, delta_u=1.0, clip_range=(0.0, 1.0), seed=117
    )
    # .tolist() / float() yield plain Python values, so the output is not
    # cluttered with NumPy scalar reprs such as "np.float64(0.91)" (NumPy >= 2).
    print("chosen idx:", idx.tolist())
    print("chosen scores:", scores[idx].tolist())
    print("eps_round:", info["eps_round"])
    print("eps_total_basic:", info["eps_total_basic"])
    eps_adv, delta_prime = info["eps_total_adv"]
    print("eps_total_adv:", (float(eps_adv), float(delta_prime)))
