# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from opacus.accountants import create_accountant
from opacus.accountants.analysis import rdp as privacy_analysis

from .mironov_rdp_to_adp_conversion import compute_adp_epsilon

# Upper bound on the noise multiplier explored by get_noise_multiplier; the
# search gives up ("privacy budget is too low") once sigma exceeds this.
MAX_SIGMA = 1_000_000.0

# Default Renyi orders, taken from Opacus' RDP accountant:
# https://github.com/meta-pytorch/opacus/blob/main/opacus/accountants/rdp.py
# Fractional orders 1.1, 1.2, ..., 10.9 followed by integer orders 12, ..., 63.
DEFAULT_ALPHAS = [1 + step / 10.0 for step in range(1, 100)] + [*range(12, 64)]


# Source: https://github.com/meta-pytorch/opacus/blob/main/opacus/accountants/utils.py
# Modified to convert RDP to (epsilon, delta)-DP with Mironov's original conversion
# (compute_adp_epsilon) instead of the Opacus accountant's get_epsilon (26.11.2025)
def get_noise_multiplier(
    *,
    target_epsilon: float,
    target_delta: float,
    sample_rate: float,
    epochs: Optional[int] = None,
    steps: Optional[int] = None,
    accountant: str = "rdp",
    epsilon_tolerance: float = 0.01,
    **kwargs,
) -> float:
    r"""
    Computes the noise multiplier sigma needed to reach a total privacy budget of
    (target_epsilon, target_delta) after the requested number of steps (or epochs)
    with the given sample_rate.

    RDP values are computed with ``privacy_analysis.compute_rdp`` and converted to
    (epsilon, delta)-DP with ``compute_adp_epsilon`` (rather than with the Opacus
    accountant's own conversion).

    Search strategy: sigma is doubled (first candidate 20) until the resulting
    epsilon is at most ``target_epsilon``; then the interval [0, sigma] is
    bisected until the achieved epsilon is within ``epsilon_tolerance`` below
    ``target_epsilon``, or until the interval can no longer be split in
    floating point.

    Args:
        target_epsilon: the privacy budget's epsilon
        target_delta: the privacy budget's delta
        sample_rate: the sampling rate (usually batch_size / n_data)
        epochs: the number of epochs to run; converted to
            ``int(epochs / sample_rate)`` steps. Exactly one of ``epochs`` and
            ``steps`` must be given.
        steps: number of steps to run
        accountant: accounting mechanism; only ``"rdp"`` is supported
        epsilon_tolerance: precision for the binary search
        **kwargs: ``alphas`` may be passed to override the RDP orders
            (defaults to ``DEFAULT_ALPHAS``); other keys are accepted and ignored.

    Returns:
        A noise level sigma whose epsilon (at ``target_delta``) does not exceed
        ``target_epsilon``.

    Raises:
        ValueError: if both or neither of ``steps``/``epochs`` are given, if
            ``accountant`` is not ``"rdp"``, or if sigma would have to exceed
            ``MAX_SIGMA``.
    """
    if (steps is None) == (epochs is None):
        raise ValueError(
            "get_noise_multiplier takes as input EITHER a number of steps or a number of epochs"
        )
    if steps is None:
        steps = int(epochs / sample_rate)

    if accountant != "rdp":
        raise ValueError("Only 'rdp' accountant is supported in this function.")

    alphas = kwargs.get("alphas", DEFAULT_ALPHAS)

    def _epsilon_for(sigma: float) -> float:
        # Compose `steps` steps in RDP at every order, then convert the RDP curve
        # to a single epsilon at target_delta.
        rdp_epsilons = privacy_analysis.compute_rdp(
            q=sample_rate,
            noise_multiplier=sigma,
            steps=steps,
            orders=alphas,
        )
        return compute_adp_epsilon(rdp_epsilons, alphas, target_delta)

    # Phase 1: grow sigma geometrically until the budget is met.
    eps_high = float("inf")
    sigma_low, sigma_high = 0, 10
    while eps_high > target_epsilon:
        sigma_high = 2 * sigma_high
        eps_high = _epsilon_for(sigma_high)
        if sigma_high > MAX_SIGMA:
            raise ValueError("The privacy budget is too low.")

    # Phase 2: bisect; sigma_high always satisfies the budget.
    while target_epsilon - eps_high > epsilon_tolerance:
        sigma = (sigma_low + sigma_high) / 2
        if sigma in (sigma_low, sigma_high):
            # The interval can no longer be split in floating point (e.g. a
            # zero/tiny tolerance); further iterations would never change state.
            break
        eps = _epsilon_for(sigma)

        if eps < target_epsilon:
            sigma_high = sigma
            eps_high = eps
        else:
            sigma_low = sigma

    return sigma_high