import unittest
from src.brownian_mechanism import BrownianMechanism
from src.brownian_mechanism import calculate_mechanism_scales_for_svt_gauss_laplace
from src.brownian_mechanism import calculate_mechanism_scales_for_svt_bounded_length
from src.brownian_mechanism import calculate_mechanism_scales_for_svt_nonneg_utility
from src.brownian_mechanism import calculate_max_utility_variance_split_nonneg_utility
from src.brownian_mechanism import calculate_mechanism_scale_no_svt
import numpy as np

class TestBrownianMechanism(unittest.TestCase):
    """Tests for repeated releases from a single ``BrownianMechanism``."""

    def test_release_raw_value_increasing_epsilon(self):
        """Raw releases at increasing epsilon keep the input shape and differ."""
        true_value = np.array([10.0, 20.0, 30.0])
        alpha = 2.0
        sensitivity = 1.0
        mechanism = BrownianMechanism(true_value, alpha, sensitivity)

        # First release with epsilon=1.0
        noisy_value1 = mechanism.release_raw_value(1.0)
        self.assertEqual(noisy_value1.shape, true_value.shape)

        # Second release with increased epsilon=2.0
        noisy_value2 = mechanism.release_raw_value(2.0)
        self.assertEqual(noisy_value2.shape, true_value.shape)

        # Ensure that the two noisy values are different
        self.assertFalse(np.array_equal(noisy_value1, noisy_value2))

    def test_release_raw_value_decreasing_epsilon(self):
        """Requesting a smaller epsilon after a larger one raises ValueError."""
        true_value = np.array([10.0, 20.0, 30.0])
        alpha = 2.0
        sensitivity = 1.0
        mechanism = BrownianMechanism(true_value, alpha, sensitivity)

        mechanism.release_raw_value(1.0)

        with self.assertRaises(ValueError):
            mechanism.release_raw_value(0.5)  # Decreasing epsilon should raise ValueError

    def test_release_value_precision_weighting(self):
        """The last ``release_value`` equals the precision-weighted raw average."""
        true_value = np.array([10.0, 20.0, 30.0])
        alpha = 2.0
        sensitivity = 1.0
        mechanism = BrownianMechanism(true_value, alpha, sensitivity)

        # Release values with increasing epsilons. The earlier calls are made
        # only for their effect on the mechanism's recorded releases; just the
        # last return value is compared below.
        mechanism.release_value(1.0)
        mechanism.release_value(3.0)
        mechanism.release_value(4.0)
        final_release = mechanism.release_value(5.5)

        # Check that the final release is a weighted average of previous releases,
        # each raw release being weighted by the inverse of its noise variance.
        precision_sum = sum(1 / var for var in mechanism.noise_variances)
        precision_weighted_sum = sum(
            raw_value / var
            for raw_value, var in zip(
                mechanism.released_noisy_raw_values, mechanism.noise_variances
            )
        )
        expected_final_release = precision_weighted_sum / precision_sum

        np.testing.assert_allclose(final_release, expected_final_release, rtol=1e-5)


class TestPrivacyCalculations(unittest.TestCase):
    """Tests for the noise-scale calculators in ``src.brownian_mechanism``.

    Each scale test feeds the returned scales back through the budget
    formulas written out in the helpers below and checks that the requested
    variance split and epsilon budget are recovered.
    """

    def _check_gauss_laplace_scales(self, utility_epsilon, utility_variance_split,
                                    alpha, utility_sensitivity):
        """Check the (sigma_1, b) pair returned by the Gauss/Laplace calculator."""
        sigma_1, b = calculate_mechanism_scales_for_svt_gauss_laplace(
            utility_epsilon, utility_variance_split, alpha, utility_sensitivity
        )

        self.assertGreater(sigma_1, 0)
        self.assertGreater(b, 0)

        # The second component's variance is taken as 2 * b**2.
        variance_1 = sigma_1**2
        variance_2 = 2 * b**2
        total_variance = variance_1 + variance_2
        self.assertAlmostEqual(variance_1 / total_variance, utility_variance_split)

        epsilon_1 = alpha * utility_sensitivity**2 / (2 * sigma_1**2)
        epsilon_2 = 2 * utility_sensitivity / b
        total_epsilon = epsilon_1 + epsilon_2
        self.assertAlmostEqual(total_epsilon, utility_epsilon, places=5)

    def _check_bounded_length_scales(self, utility_epsilon, utility_variance_split,
                                     alpha, utility_sensitivity, max_releases):
        """Check the (sigma_1, sigma_2) pair returned by the bounded-length calculator."""
        sigma_1, sigma_2 = calculate_mechanism_scales_for_svt_bounded_length(
            utility_epsilon, utility_variance_split, alpha, utility_sensitivity,
            max_releases
        )

        self.assertGreater(sigma_1, 0)
        self.assertGreater(sigma_2, 0)

        variance_1 = sigma_1**2
        variance_2 = sigma_2**2
        total_variance = variance_1 + variance_2
        self.assertAlmostEqual(variance_1 / total_variance, utility_variance_split, places=5)

        epsilon_1 = alpha * utility_sensitivity**2 / (2 * sigma_1**2)
        epsilon_2 = alpha * (2 * utility_sensitivity)**2 / (2 * sigma_2**2)
        # Part of the budget is set aside for the bound on the number of
        # releases, so the two scales must account for the remainder only.
        epsilon_for_bounded_length = np.log(max_releases + 1) / (alpha - 1)
        total_epsilon = epsilon_1 + epsilon_2
        self.assertAlmostEqual(total_epsilon, utility_epsilon - epsilon_for_bounded_length)

    def _check_nonneg_utility_scales(self, utility_epsilon, utility_variance_split,
                                     alpha, utility_sensitivity, utility_threshold):
        """Check the (sigma_1, sigma_2) pair returned by the non-negative-utility calculator."""
        sigma_1, sigma_2 = calculate_mechanism_scales_for_svt_nonneg_utility(
            utility_epsilon, utility_variance_split, alpha, utility_sensitivity,
            utility_threshold
        )

        self.assertGreater(sigma_1, 0)
        self.assertGreater(sigma_2, 0)

        variance_1 = sigma_1**2
        variance_2 = sigma_2**2
        total_variance = variance_1 + variance_2
        self.assertAlmostEqual(variance_1 / total_variance, utility_variance_split, places=5)

        # NOTE(review): unlike the other calculators' checks, epsilon_1 here is
        # not divided by 2. The original author flagged this for checking.
        epsilon_1 = alpha * utility_sensitivity**2 / (sigma_1**2)
        epsilon_2 = alpha * (2 * utility_sensitivity)**2 / (2 * sigma_2**2)
        term1 = 1 + 9 * utility_threshold**2 / sigma_1**2
        term2 = np.exp(utility_threshold**2 / sigma_1**2)
        epsilon_3 = np.log(1 + 2 * 3**0.5 * np.pi * term1 * term2) / (2 * (alpha - 1))
        total_epsilon = epsilon_1 + epsilon_2 + epsilon_3
        self.assertAlmostEqual(total_epsilon, utility_epsilon)

    def test_calculate_mechanism_scales_gauss_laplace_case_1(self):
        """Gauss/Laplace scales with an even variance split."""
        self._check_gauss_laplace_scales(
            utility_epsilon=1.0,
            utility_variance_split=0.5,
            alpha=2.0,
            utility_sensitivity=1.0,
        )

    def test_calculate_mechanism_scales_gauss_laplace_case_2(self):
        """Gauss/Laplace scales with non-round inputs."""
        self._check_gauss_laplace_scales(
            utility_epsilon=0.7,
            utility_variance_split=0.74,
            alpha=3.2,
            utility_sensitivity=1.2,
        )

    def test_calculate_mechanism_scales_bounded_length_case_1(self):
        """Bounded-length scales with an even variance split."""
        self._check_bounded_length_scales(
            utility_epsilon=2.0,
            utility_variance_split=0.5,
            alpha=2.0,
            utility_sensitivity=1.0,
            max_releases=4,
        )

    def test_calculate_mechanism_scales_bounded_length_case_2(self):
        """Bounded-length scales with non-round inputs."""
        self._check_bounded_length_scales(
            utility_epsilon=3.9,
            utility_variance_split=0.76,
            alpha=4.3,
            utility_sensitivity=1.1,
            max_releases=6,
        )

    def test_calculate_utility_variance_split(self):
        """The maximum split s satisfies (1 - s) / s == sqrt(3)."""
        utility_variance_split = calculate_max_utility_variance_split_nonneg_utility()
        t_p = (1 - utility_variance_split) / utility_variance_split
        self.assertAlmostEqual(t_p, 3**0.5)

    def test_calculate_mechanism_scales_nonneg_utility_case_1(self):
        """Non-negative-utility scales with unit sensitivity."""
        self._check_nonneg_utility_scales(
            utility_epsilon=2.0,
            utility_variance_split=0.3,
            alpha=2.0,
            utility_sensitivity=1.0,
            utility_threshold=0.8,
        )

    def test_calculate_mechanism_scales_nonneg_utility_case_2(self):
        """Non-negative-utility scales with a very small sensitivity."""
        self._check_nonneg_utility_scales(
            utility_epsilon=1.1,
            utility_variance_split=0.35,
            alpha=4.6,
            utility_sensitivity=1 / 14637,
            utility_threshold=0.74,
        )

    def test_calculate_mechanism_scales_nonneg_utility_forbids_incorrect_variance_split(self):
        """A variance split of 0.37 is rejected with ValueError."""
        utility_epsilon = 1.4
        # Presumably rejected because 0.37 exceeds the maximum split of about
        # 0.366 implied by test_calculate_utility_variance_split; confirm
        # against the implementation.
        utility_variance_split = 0.37
        alpha = 4.6
        utility_sensitivity = 5.3
        utility_threshold = 0.74

        with self.assertRaises(ValueError):
            calculate_mechanism_scales_for_svt_nonneg_utility(
                utility_epsilon, utility_variance_split, alpha, utility_sensitivity,
                utility_threshold
            )

    def test_calculate_mechanism_scale_no_svt(self):
        """The no-SVT scale reproduces the requested epsilon budget."""
        utility_epsilon = 1.4
        alpha = 4.6
        utility_sensitivity = 5.3
        max_release_count = 5

        sigma = calculate_mechanism_scale_no_svt(utility_epsilon, utility_sensitivity, max_release_count, alpha)

        total_epsilon = (max_release_count - 1) * alpha * utility_sensitivity**2 / (2 * sigma**2)
        self.assertAlmostEqual(total_epsilon, utility_epsilon)