import numpy as np
from typing import Tuple, List

from opacus.accountants.analysis import rdp as privacy_analysis


def rdp_to_adp(rdp_epsilon, alpha, delta):
    r"""
    Converts Rényi Differential Privacy (RDP) parameters to Approximate Differential Privacy (ADP) parameters.
    Uses the original conversion formula from Mironov (2017) that also applies to ex-post definitions (Ghazi et al. 2025):

        adp_epsilon = rdp_epsilon + log(1 / delta) / (alpha - 1)

    Args:
        rdp_epsilon: The RDP epsilon value.
        alpha: The order of the Rényi divergence. Must be greater than 1.
        delta: The target delta for ADP. Must be strictly positive.

    Returns:
        The corresponding ADP epsilon value.

    Raises:
        ValueError: If ``alpha <= 1`` or if ``delta`` is not strictly positive (including NaN).
    """
    if alpha <= 1:
        raise ValueError("Alpha must be greater than 1 for RDP to ADP conversion.")
    # `not delta > 0` (rather than `delta <= 0`) also rejects NaN. Without this check a zero
    # delta surfaces as a bare ZeroDivisionError and a negative delta silently yields NaN.
    if not delta > 0:
        raise ValueError("Delta must be greater than 0 for RDP to ADP conversion.")

    adp_epsilon = rdp_epsilon + np.log(1 / delta) / (alpha - 1)
    return adp_epsilon


def compute_adp_epsilon(rdp_epsilons, alphas, delta, return_best_alpha=False) -> float | Tuple[float, float, float]:
    r"""
    Finds the smallest ADP epsilon obtainable from paired RDP epsilons and orders.

    Each ``(rdp_epsilon, alpha)`` pair is converted with ``rdp_to_adp`` at the given
    ``delta``, and the smallest converted value is selected.

    Args:
        rdp_epsilons: RDP epsilon values, one per order.
        alphas: Rényi orders, paired position-by-position with ``rdp_epsilons``.
        delta: The target delta for ADP.
        return_best_alpha: If True, the selected order and its RDP epsilon are returned
            together with the ADP epsilon.

    Returns:
        The minimum ADP epsilon, or the tuple ``(adp_epsilon, alpha, rdp_epsilon)`` for the
        minimizing pair when ``return_best_alpha`` is True.
    """
    candidates = []
    for eps, order in zip(rdp_epsilons, alphas):
        candidates.append(rdp_to_adp(eps, order, delta))

    if not return_best_alpha:
        return min(candidates)

    # The index of the smallest candidate selects the matching order and RDP epsilon.
    best = np.argmin(candidates)
    return candidates[best], alphas[best], rdp_epsilons[best]


def compute_adp_epsilon_from_accountant(accountant, alphas, delta, return_best_alpha=False):
    r"""
    Computes the ADP epsilon from an RDP accountant.

    The RDP epsilon of every entry in ``accountant.history`` is evaluated at each order in
    ``alphas`` and summed over all entries; the totals are then converted to ADP.

    Args:
        accountant: An RDP accountant object storing history of mechanism applications.
            Its ``history`` is iterated as ``(noise_multiplier, sample_rate, num_steps)`` tuples.
        alphas: A list of alpha values for which to compute the ADP epsilon.
        delta: The target delta for ADP.
        return_best_alpha: If True, also returns the alpha and the accumulated RDP epsilon
            that give the minimum ADP epsilon.

    Returns:
        The minimum ADP epsilon value computed from the accountant's RDP epsilons and the provided alphas.
        If ``return_best_alpha`` is True, the tuple ``(adp_epsilon, alpha, rdp_epsilon)`` instead.
        With an empty history the accumulated RDP is zero at every order, so the result comes
        from the conversion term alone.
    """

    # Adapted under Apache 2.0 license from (26.11.2025): 
    # https://github.com/meta-pytorch/opacus/blob/main/opacus/accountants/rdp.py
    # Accumulate from a zero vector with one slot per order: a plain `sum([...])` over an
    # empty history would yield the int 0, which cannot be paired with `alphas` downstream.
    rdp = np.zeros(len(alphas))
    for noise_multiplier, sample_rate, num_steps in accountant.history:
        rdp = rdp + privacy_analysis.compute_rdp(
            q=sample_rate,
            noise_multiplier=noise_multiplier,
            steps=num_steps,
            orders=alphas,
        )

    return compute_adp_epsilon(rdp, alphas, delta, return_best_alpha=return_best_alpha)


def compute_adp_epsilon_from_accountant_with_threshold_check(
        accountant, alphas, delta, threshold_check_rdp_epsilons, return_best_alpha=False
        ) -> float | Tuple[float, float, float]:
    r"""
    Computes the ADP epsilon from an RDP accountant with threshold check RDP epsilons.

    The RDP epsilon of every entry in ``accountant.history`` is evaluated at each order in
    ``alphas`` and summed over all entries. At each order, the larger of that total and the
    corresponding threshold check RDP epsilon is then converted to ADP.

    Args:
        accountant: An RDP accountant object storing history of mechanism applications.
            Its ``history`` is iterated as ``(noise_multiplier, sample_rate, num_steps)`` tuples.
        alphas: A list of alpha values for which to compute the ADP epsilon.
        delta: The target delta for ADP.
        threshold_check_rdp_epsilons: A list of RDP epsilons corresponding to threshold checks.
            It is combined with the accumulated RDP via ``np.maximum``, so it must be
            broadcastable against one value per alpha.
        return_best_alpha: If True, also returns the alpha and the (maximum-adjusted) RDP
            epsilon that give the minimum ADP epsilon.

    Returns:
        The minimum ADP epsilon value computed from the accountant's RDP epsilons and the provided alphas,
        accounting for the threshold check RDP epsilons.
        If ``return_best_alpha`` is True, the tuple ``(adp_epsilon, alpha, rdp_epsilon)`` instead.
    """

    # Adapted under Apache 2.0 license from (26.11.2025): 
    # https://github.com/meta-pytorch/opacus/blob/main/opacus/accountants/rdp.py
    # Accumulate from a zero vector with one slot per order so that an empty history still
    # produces a per-order array (a plain `sum([...])` would collapse to the int 0).
    rdp = np.zeros(len(alphas))
    for noise_multiplier, sample_rate, num_steps in accountant.history:
        rdp = rdp + privacy_analysis.compute_rdp(
            q=sample_rate,
            noise_multiplier=noise_multiplier,
            steps=num_steps,
            orders=alphas,
        )

    # Element-wise maximum: at each order, use whichever RDP epsilon is larger.
    max_rdp = np.maximum(rdp, threshold_check_rdp_epsilons)

    return compute_adp_epsilon(max_rdp, alphas, delta, return_best_alpha=return_best_alpha)