import unittest
from dp_sgd.mironov_rdp_accounting import get_noise_multiplier, compute_adp_epsilon, rdp_to_adp
from dp_sgd.mironov_rdp_accounting.get_noise_multiplier import DEFAULT_ALPHAS

import numpy as np

from opacus.accountants.analysis import rdp as privacy_analysis



class TestMironovPrivacyAccounting(unittest.TestCase):
    """Tests for the Mironov RDP -> approximate (epsilon, delta)-DP accounting helpers."""

    def test_rdp_to_adp(self):
        """rdp_to_adp converts one RDP order via eps + log(1/delta) / (alpha - 1)."""
        rdp_epsilon = 1.0
        alpha = 2.0
        delta = 1e-5
        adp_epsilon = rdp_to_adp(rdp_epsilon, alpha, delta)
        expected_adp_epsilon = rdp_epsilon + (np.log(1 / delta) / (alpha - 1))
        self.assertAlmostEqual(adp_epsilon, expected_adp_epsilon)

    def test_compute_epsilon(self):
        """compute_adp_epsilon returns the smallest per-order converted epsilon."""
        rdp_epsilons = [0.5, 1.0, 1.5]
        alphas = [2.0, 3.0, 4.0]
        delta = 1e-5
        adp_epsilon = compute_adp_epsilon(rdp_epsilons, alphas, delta)
        # Convert each (epsilon, alpha) pair independently; zip keeps the two
        # lists aligned without manual indexing.
        expected_adp_epsilons = [
            rdp_to_adp(rdp_epsilon, alpha, delta)
            for rdp_epsilon, alpha in zip(rdp_epsilons, alphas)
        ]
        expected_min_adp_epsilon = min(expected_adp_epsilons)
        self.assertAlmostEqual(adp_epsilon, expected_min_adp_epsilon)

    def test_get_noise_multiplier(self):
        """get_noise_multiplier yields a sigma whose epsilon is just below the target.

        The returned noise multiplier is re-accounted independently with
        Opacus' ``compute_rdp`` over ``DEFAULT_ALPHAS``. The resulting
        epsilon must not exceed ``target_epsilon`` and must lie within 0.02
        of it.
        """
        target_epsilon = 2.71
        target_delta = 1e-5
        sample_rate = 0.014
        steps = 4238
        noise_multiplier = get_noise_multiplier(
            target_epsilon=target_epsilon,
            target_delta=target_delta,
            sample_rate=sample_rate,
            steps=steps,
        )

        self.assertIsInstance(noise_multiplier, float)
        self.assertGreater(noise_multiplier, 0)

        # Cross-check against an independent RDP implementation (Opacus)
        # evaluated at the same orders the calibration used.
        rdp_epsilons = privacy_analysis.compute_rdp(
            q=sample_rate,
            noise_multiplier=noise_multiplier,
            steps=steps,
            orders=DEFAULT_ALPHAS,
        )
        eps = compute_adp_epsilon(rdp_epsilons, DEFAULT_ALPHAS, target_delta)

        self.assertLessEqual(eps, target_epsilon)
        # NOTE(review): the 0.02 slack presumably reflects the search
        # tolerance inside get_noise_multiplier; confirm against its
        # implementation.
        self.assertGreaterEqual(eps, target_epsilon - 0.02)