import numpy as np
import scipy.optimize as opt
from collections.abc import Callable

from src.basic_mechanisms import gaussian_mechanism_rdp, gaussian_mechanism_rdp_variance, gaussian_mechanism_rdp_epsilon
from src.basic_mechanisms import laplace_mechanism, laplace_mechanism_variance, laplace_mechanism_scale

class BrownianMechanism:
    """Implements the Brownian mechanism with ex-post RDP guarantee using 
    precision-weighting.

    Every release draws fresh Gaussian noise whose variance is obtained from
    ``gaussian_mechanism_rdp_variance`` for the epsilon increment since the
    previous release. ``release_value`` combines all raw releases made so far,
    weighting each one by the inverse of its noise variance.
    """
    def __init__(self, true_value: np.ndarray, alpha: float, sensitivity: float):
        """Initialize the Brownian mechanism with ex-post RDP guarantee.

        Args:
            true_value (np.ndarray): Private value to release.
            alpha (float): RDP alpha parameter. Must be > 1.
            sensitivity (float): Sensitivity of true_value. Must be positive.

        Raises:
            ValueError: If alpha <= 1 or sensitivity <= 0.
        """

        if alpha <= 1:
            raise ValueError(f"Alpha must be greater than 1. Got alpha={alpha}.")
        if sensitivity <= 0:
            raise ValueError(f"Sensitivity must be positive. Got sensitivity={sensitivity}.")

        self.true_value = true_value
        self.alpha = alpha
        self.sensitivity = sensitivity
        # Cumulative epsilon of the most recent release; None until the first release.
        self.previous_epsilon = None
        # Parallel lists: noise variance and noisy value of every raw release so far.
        self.noise_variances = []
        self.released_noisy_raw_values = []

    def release_raw_value(self, new_epsilon: float) -> np.ndarray:
        """Release noisy value without precision-weighting.

        Args:
            new_epsilon (float): Epsilon to release with. Must be positive on
                the first release and greater than the previous epsilon on
                later releases.

        Returns:
            np.ndarray: Released unweighted noisy value.

        Raises:
            ValueError: If new_epsilon does not increase the cumulative epsilon.
        """
        if self.previous_epsilon is None:
            # Nothing has been spent yet, so the increment is measured from
            # zero and must itself be positive; otherwise the requested noise
            # variance would be computed for a non-positive budget.
            if not new_epsilon > 0:
                raise ValueError(f"First epsilon must be positive. Got new_epsilon={new_epsilon}.")
            epsilon_difference = new_epsilon
        else:
            if not new_epsilon > self.previous_epsilon:
                raise ValueError("New epsilon must be greater than the previous epsilon.")
            epsilon_difference = new_epsilon - self.previous_epsilon

        noise_variance = gaussian_mechanism_rdp_variance(epsilon_difference, self.alpha, self.sensitivity)
        noise = np.random.normal(0, np.sqrt(noise_variance), size=self.true_value.shape)
        noisy_value = self.true_value + noise

        # Commit state only once the draw has succeeded, so a failure above
        # cannot leave previous_epsilon out of sync with the recorded releases.
        self.previous_epsilon = new_epsilon
        self.noise_variances.append(noise_variance)
        self.released_noisy_raw_values.append(noisy_value)

        return noisy_value

    def release_value(self, new_epsilon: float) -> np.ndarray:
        """Releases noisy value with precision-weighting.

        Args:
            new_epsilon (float): Epsilon to release with. Must be greater than previous epsilon.

        Returns:
            np.ndarray: Released precision-weighted noisy value.

        Raises:
            ValueError: If new_epsilon does not increase the cumulative epsilon.
        """
        # The raw release is recorded in the instance lists; only the
        # combination of all recorded releases is returned here.
        self.release_raw_value(new_epsilon)

        precision_sum = sum(1 / variance for variance in self.noise_variances)
        precision_weighted_value_sum = np.sum(np.stack([
            raw_value / variance
            for raw_value, variance in zip(self.released_noisy_raw_values, self.noise_variances)
        ], axis=0), axis=0)

        return precision_weighted_value_sum / precision_sum


def run_until_non_private_predicate(
        train_value: np.ndarray, stopping_condition: Callable[[np.ndarray], bool],
        epsilons: list[float], alpha: float, sensitivity: float
        ) -> tuple[list[np.ndarray], float, float]:
    """Release with a Brownian mechanism until a non-private predicate accepts
    the output.

    The candidate epsilons are tried in the order given. After each release the
    predicate is evaluated on the released value, and the loop stops at the
    first release for which it returns True (or when the list is exhausted).

    Args:
        train_value (np.ndarray): Private value to release.
        stopping_condition (Callable[[np.ndarray], bool]): Stopping condition predicate.
        epsilons (list[float]): Epsilons to try. Must be increasing.
        alpha (float): RDP alpha parameter.
        sensitivity (float): Sensitivity of train_value.

    Returns:
        tuple[list[np.ndarray], float, float]: All released values in release
        order, the last epsilon tried (0.0 if epsilons is empty), and the
        Gaussian-mechanism noise variance computed for that last epsilon.
    """
    brownian = BrownianMechanism(train_value, alpha, sensitivity)
    outputs = []

    # Stays at 0.0 when there is nothing to iterate over.
    last_epsilon = 0.0
    for last_epsilon in epsilons:
        outputs.append(brownian.release_value(last_epsilon))
        if stopping_condition(outputs[-1]):
            break

    final_variance = gaussian_mechanism_rdp_variance(last_epsilon, alpha, sensitivity)
    return outputs, last_epsilon, final_variance


# Zhu and Wang (2020), Theorem 8, eq. (3)
def calculate_mechanism_scales_for_svt_gauss_laplace(
        utility_epsilon: float, utility_variance_split: float,
        alpha: float, utility_sensitivity: float
        ) -> tuple[float, float]:
    """Calculate noise scales for utility threshold check using SVT with 
    Gaussian + Laplace noise.

    The scales are chosen so that
    ``alpha * utility_sensitivity**2 / (2 * sigma_1**2) + 2 * utility_sensitivity / b``
    equals ``utility_epsilon``, with the Laplace standard deviation fixed to
    ``sqrt((1 - split) / split)`` times the Gaussian standard deviation.

    Reference: Zhu and Wang "Improving Sparse Vector Technique with Renyi Differential Privacy"
    NeurIPS 2020, Theorem 8, eq. (3)

    Args:
        utility_epsilon (float): Epsilon for SVT. Must be positive.
        utility_variance_split (float): Fraction of noise variance allocated to the Gaussian mechanism.
            Must be strictly between 0 and 1.
        alpha (float): RDP alpha parameter.
        utility_sensitivity (float): Sensitivity of utility value.

    Returns:
        tuple[float, float]: Gaussian noise scale, Laplace noise scale

    Raises:
        ValueError: If utility_epsilon is not positive or utility_variance_split
            is not strictly between 0 and 1.
    """
    if not utility_epsilon > 0:
        raise ValueError(f"utility_epsilon must be positive. Got utility_epsilon={utility_epsilon}.")
    # Outside (0, 1) the ratio below is either undefined (division by zero) or
    # negative, and a negative base under **0.5 silently becomes complex.
    if not 0 < utility_variance_split < 1:
        raise ValueError(
            "utility_variance_split must be strictly between 0 and 1. "
            f"Got utility_variance_split={utility_variance_split}."
        )

    # Ratio of Laplace to Gaussian standard deviation implied by the variance split.
    t_p = ((1 - utility_variance_split) / utility_variance_split)**0.5

    # Coefficients of a1 * sigma_1**2 + a2 * sigma_1 + a3 = 0, obtained by
    # multiplying the epsilon equation by sigma_1**2.
    a1 = utility_epsilon
    a2 = -2 * 2**0.5 * utility_sensitivity / t_p
    a3 = -alpha * utility_sensitivity**2 / 2

    # Larger root of the quadratic.
    sigma_1 = (-a2 + (a2**2 - 4 * a1 * a3)**0.5) / (2 * a1)
    sigma_2 = t_p * sigma_1
    # A Laplace distribution with scale b has standard deviation sqrt(2) * b.
    b = 1 / 2**0.5 * sigma_2

    return sigma_1, b


def get_noise_scales_svt_gauss_laplace(
        utility_epsilon: float, final_epsilon: float, utility_sensitivity: float, 
        main_sensitivity: float, alpha: float,
        utility_variance_split: float
        ) -> tuple[float, float, float]:
    """Collect every noise scale used with the Gaussian + Laplace SVT check.

    Args:
        utility_epsilon (float): Epsilon for threshold check.
        final_epsilon (float): Final epsilon of Brownian mechanism.
        utility_sensitivity (float): Sensitivity of utility value.
        main_sensitivity (float): Sensitivity of real_value.
        alpha (float): RDP alpha parameter.
        utility_variance_split (float): Fraction of noise variance allocated to the Gaussian mechanism.

    Returns:
        tuple[float, float, float]: Utility check Gaussian noise scale, utility check Laplace noise scale, Brownian mechanism noise scale.
    """
    utility_scales = calculate_mechanism_scales_for_svt_gauss_laplace(
        utility_epsilon, utility_variance_split, alpha, utility_sensitivity
    )
    gauss_scale, laplace_scale = utility_scales

    # The helper yields a variance; the reported scale is its square root.
    main_variance = gaussian_mechanism_rdp_variance(final_epsilon, alpha, main_sensitivity)
    return gauss_scale, laplace_scale, main_variance**0.5


def get_epsilons_with_noise_scales_svt_gauss_laplace(
        sigma_1: float, b: float, main_sigma: float, utility_sensitivity: float,
        main_sensitivity: float, alpha: float,
        ) -> tuple[float, float]:
    """Recover epsilons from the noise scales of SVT with Gaussian + Laplace noise.

    Args:
        sigma_1 (float): Noise scale for utility threshold mechanism.
        b (float): Noise scale for utility value mechanism.
        main_sigma (float): Noise scale for Brownian mechanism.
        utility_sensitivity (float): Sensitivity of utility value.
        main_sensitivity (float): Sensitivity of released value.
        alpha (float): RDP alpha parameter.

    Returns:
        tuple[float, float]: Epsilon for the utility check (threshold mechanism
        epsilon plus ``2 * utility_sensitivity / b`` for the value mechanism),
        epsilon for Brownian mechanism.
    """
    threshold_cost = gaussian_mechanism_rdp_epsilon(sigma_1, alpha, utility_sensitivity)
    value_cost = 2 * utility_sensitivity / b
    brownian_cost = gaussian_mechanism_rdp_epsilon(main_sigma, alpha, main_sensitivity)

    return threshold_cost + value_cost, brownian_cost


def run_until_private_utility_above_threshold_gauss_laplace(
        real_value: np.ndarray, validation_data: np.ndarray, 
        utility_function: Callable[[np.ndarray, np.ndarray, float, float], float], 
        utility_threshold: float,
        main_epsilons: list[float], utility_epsilon: float,
        utility_variance_split: float, alpha: float, main_sensitivity: float,
        utility_sensitivity: float, 
        ) -> tuple[list[np.ndarray], float]:
    """Release with the Brownian mechanism until the utility measured on
    private validation data crosses a threshold.

    The crossing is tested with SVT: the threshold is perturbed with Gaussian
    noise and every utility value with Laplace noise.

    utility_function is called as
    (released value, validation_data, release variance, release epsilon).

    Args:
        real_value (np.ndarray): Private value to release.
        validation_data (np.ndarray): Validation data to compute utility.
        utility_function (Callable[[np.ndarray, np.ndarray, float, float], float]): Function that computes utility.
        utility_threshold (float): Acceptable utility threshold.
        main_epsilons (list[float]): Epsilons for Brownian mechanism.
        utility_epsilon (float): Epsilon for utility check.
        utility_variance_split (float): Fraction of noise variance allocated to the Gaussian mechanism.
        alpha (float): RDP alpha parameter.
        main_sensitivity (float): Sensitivity of real_value.
        utility_sensitivity (float): Sensitivity of utility.

    Returns:
        tuple[list[np.ndarray], float]: List of released values, final epsilon.
    """
    gauss_scale, laplace_scale = calculate_mechanism_scales_for_svt_gauss_laplace(
        utility_epsilon, utility_variance_split, alpha, utility_sensitivity
    )

    def perturb_threshold(threshold):
        # Gaussian noise on the threshold.
        return threshold + np.random.normal(0, gauss_scale)

    def perturb_utility(utility):
        # Laplace noise on each utility value.
        return utility + np.random.laplace(0, laplace_scale)

    return _run_until_private_utility_above_threshold_impl(
        real_value, validation_data, utility_function, utility_threshold,
        perturb_threshold, perturb_utility,
        main_epsilons, alpha, main_sensitivity
    )


# Zhu and Wang (2020), Remark (Bounded-length SVT)
# Not used due to large overhead in privacy cost.
def calculate_mechanism_scales_for_svt_bounded_length(
        utility_epsilon: float, utility_variance_split: float,
        alpha: float, utility_sensitivity: float, max_releases: int
        ) -> tuple[float, float]:
    """Calculate noise scales for utility threshold check using SVT with bounded
    length.

    ``log(1 + max_releases) / (alpha - 1)`` is first subtracted from
    utility_epsilon as the bounded-length overhead. The remainder is split so
    that ``alpha * s**2 / (2 * sigma_1**2) + alpha * (2 * s)**2 / (2 * sigma_2**2)``
    (with ``s = utility_sensitivity``) equals it, where
    ``sigma_2**2 = (1 - split) / split * sigma_1**2``.

    Reference: Zhu and Wang "Improving Sparse Vector Technique with Renyi Differential Privacy"
    NeurIPS 2020, Remark (Bounded-length SVT)

    Args:
        utility_epsilon (float): Epsilon for SVT.
        utility_variance_split (float): Fraction of noise variance allocated to the threshold mechanism.
            Must be strictly between 0 and 1.
        alpha (float): RDP alpha parameter. Must be > 1.
        utility_sensitivity (float): Sensitivity of utility value.
        max_releases (int): Maximum number of releases from SVT.

    Raises:
        ValueError: alpha must be greater than 1 and utility_variance_split
        must lie strictly between 0 and 1.
        ValueError: utility_epsilon must be large enough to accommodate overhead
        from the bounded length privacy accounting.

    Returns:
        tuple[float, float]: Threshold noise scale, value noise scale
    """
    # With alpha < 1 the overhead below would come out negative and silently
    # inflate the budget; alpha == 1 divides by zero.
    if not alpha > 1:
        raise ValueError(f"Alpha must be greater than 1. Got alpha={alpha}.")
    # Outside (0, 1) the variance ratio is undefined or negative, which would
    # lead to a division by zero or NaN noise scales.
    if not 0 < utility_variance_split < 1:
        raise ValueError(
            "utility_variance_split must be strictly between 0 and 1. "
            f"Got utility_variance_split={utility_variance_split}."
        )

    epsilon_for_bounded_length = np.log(1 + max_releases) / (alpha - 1)
    epsilon_for_mechanisms = utility_epsilon - epsilon_for_bounded_length
    if epsilon_for_mechanisms <= 0: 
        raise ValueError(f"utility_epsilon is not large enough to accommodate extra privacy cost from bounded length. utility_epsilon was {utility_epsilon}, cost for bounded length is {epsilon_for_bounded_length}")

    # Ratio of value-noise variance to threshold-noise variance (a variance
    # ratio here, not a standard-deviation ratio).
    t_p = ((1 - utility_variance_split) / utility_variance_split)
    sigma2_1 = alpha * utility_sensitivity**2 / (2 * epsilon_for_mechanisms) + 2 * alpha * utility_sensitivity**2 / (t_p * epsilon_for_mechanisms)
    sigma2_2 = t_p * sigma2_1

    return (sigma2_1**0.5, sigma2_2**0.5)


def run_until_private_utility_above_threshold_bounded_length(
        real_value: np.ndarray, validation_data: np.ndarray, 
        utility_function: Callable[[np.ndarray, np.ndarray, float, float], float], 
        utility_threshold: float,
        main_epsilons: list[float], utility_epsilon: float,
        utility_variance_split: float, alpha: float, main_sensitivity: float,
        utility_sensitivity: float,
        ) -> tuple[list[np.ndarray], float]:
    """Release with the Brownian mechanism until the utility measured on
    private validation data crosses a threshold.

    The crossing is tested with bounded-length SVT: both the threshold and
    every utility value are perturbed with Gaussian noise, and the maximum
    number of releases passed to the scale calculation is
    ``len(main_epsilons) - 1``.

    utility_function is called as
    (released value, validation_data, release variance, release epsilon).

    Args:
        real_value (np.ndarray): Private value to release.
        validation_data (np.ndarray): Validation data to compute utility.
        utility_function (Callable[[np.ndarray, np.ndarray, float, float], float]): Function that computes utility.
        utility_threshold (float): Acceptable utility threshold.
        main_epsilons (list[float]): Epsilons for Brownian mechanism.
        utility_epsilon (float): Epsilon for utility check.
        utility_variance_split (float): Fraction of noise variance allocated to the threshold mechanism.
        alpha (float): RDP alpha parameter.
        main_sensitivity (float): Sensitivity of real_value.
        utility_sensitivity (float): Sensitivity of utility.

    Returns:
        tuple[list[np.ndarray], float]: List of released values, final epsilon.
    """
    release_bound = len(main_epsilons) - 1
    threshold_scale, value_scale = calculate_mechanism_scales_for_svt_bounded_length(
        utility_epsilon, utility_variance_split, alpha, utility_sensitivity,
        release_bound,
    )

    def perturb_threshold(threshold):
        # Gaussian noise on the threshold.
        return threshold + np.random.normal(0, threshold_scale)

    def perturb_utility(utility):
        # Gaussian noise on each utility value.
        return utility + np.random.normal(0, value_scale)

    return _run_until_private_utility_above_threshold_impl(
        real_value, validation_data, utility_function, utility_threshold,
        perturb_threshold, perturb_utility,
        main_epsilons, alpha, main_sensitivity
    )


def calculate_max_utility_variance_split_nonneg_utility() -> float:
    """Return the largest utility_variance_split accepted by the non-negative
    utility SVT scale calculation.

    The value is ``1 / (1 + sqrt(3))``.

    NOTE(review): presumably derived from a precondition of Zhu and Wang
    (2020), Proposition 10, relating the two noise scales -- confirm against
    the paper that the bound is meant for a variance fraction.

    Returns:
        float: Maximum allowed variance split.
    """
    sqrt_three = 3**0.5
    return 1 / (1 + sqrt_three)


# Zhu and Wang (2020), Proposition 10
# Not used due to large overhead in privacy cost.
def calculate_mechanism_scales_for_svt_nonneg_utility(
        utility_epsilon: float, utility_variance_split: float,
        alpha: float, utility_sensitivity: float,
        utility_threshold: float
        ) -> tuple[float, float]:
    """Calculate noise scales for utility threshold check using SVT with 
    non-negative utility.

    The threshold noise scale is found numerically as the root of
    ``total_epsilon(sigma_1) == utility_epsilon`` on the bracket
    ``[1e-10, 1e10]``; the utility value noise scale is
    ``sqrt((1 - split) / split)`` times the threshold noise scale.

    Reference: Zhu and Wang "Improving Sparse Vector Technique with Renyi Differential Privacy"
    NeurIPS 2020, Proposition 10

    Args:
        utility_epsilon (float): Epsilon for SVT.
        utility_variance_split (float): Fraction of noise variance allocated to the threshold mechanism.
        alpha (float): RDP alpha parameter. Must be > 1.
        utility_sensitivity (float): Sensitivity of utility value.
        utility_threshold (float): Acceptable utility threshold.

    Returns:
        tuple[float, float]: Threshold noise scale, utility value noise scale

    Raises:
        ValueError: If alpha <= 1, if utility_variance_split is not in
            (0, max split], if utility_epsilon does not exceed the smallest
            achievable epsilon, or if the root finder cannot bracket a root.
        RuntimeError: If the root finder does not return a float.
    """

    if not alpha > 1:
        raise ValueError(f"Alpha must be greater than 1. Got alpha={alpha}.")

    # Validate the split before it is used in a division below.
    max_split = calculate_max_utility_variance_split_nonneg_utility()
    if not 0 < utility_variance_split <= max_split:
        raise ValueError(f"utility_variance_split must be in (0, {max_split}], got {utility_variance_split}.")
    t_p = np.sqrt(((1 - utility_variance_split) / utility_variance_split))

    log_constant = 2 * 3**0.5 * np.pi
    log_divisor = 2 * (alpha - 1)

    # Infimum of total_epsilon below: as sigma_1 grows, the first two terms
    # vanish and the logarithm tends to log(1 + log_constant). The division by
    # log_divisor belongs outside the logarithm, matching total_epsilon. The
    # infimum is never attained, hence the non-strict comparison.
    min_epsilon = np.log(1 + log_constant) / log_divisor
    if utility_epsilon <= min_epsilon:
        raise ValueError(f"utility_epsilon must be greater than {min_epsilon}, got {utility_epsilon}.")

    def total_epsilon(sigma_1): 
        threshold_ratio = utility_threshold**2 / sigma_1**2
        term1 = alpha * utility_sensitivity**2 / sigma_1**2 # TODO: check that not dividing by 2 is correct
        term2 = 2 * alpha * utility_sensitivity**2 / (t_p * sigma_1)**2
        # log(1 + log_constant * (1 + 9 * r) * exp(r)) evaluated as
        # logaddexp(0, log(log_constant * (1 + 9 * r)) + r), so exp(r) cannot
        # overflow when sigma_1 is tiny (as at the lower bracket end).
        log_term = np.logaddexp(0, np.log(log_constant * (1 + 9 * threshold_ratio)) + threshold_ratio)

        return term1 + term2 + log_term / log_divisor

    sigma_1 = opt.brentq(lambda sigma_1: total_epsilon(sigma_1) - utility_epsilon, 1e-10, 1e10)
    if not isinstance(sigma_1, float): raise RuntimeError("Optimization failed")
    sigma_2 = t_p * sigma_1
    return (sigma_1, sigma_2)

def run_until_private_utility_above_threshold_nonneg_utility(
        real_value: np.ndarray, validation_data: np.ndarray, 
        utility_function: Callable[[np.ndarray, np.ndarray, float, float], float], 
        utility_threshold: float,
        main_epsilons: list[float], utility_epsilon: float,
        utility_variance_split: float, alpha: float, main_sensitivity: float,
        utility_sensitivity: float, 
        ) -> tuple[list[np.ndarray], float]:
    """Release with the Brownian mechanism until the utility measured on
    private validation data crosses a threshold.

    The crossing is tested with the non-negative utility variant of SVT: both
    the threshold and every utility value are perturbed with Gaussian noise.

    utility_function is called as
    (released value, validation_data, release variance, release epsilon).

    Args:
        real_value (np.ndarray): Private value to release.
        validation_data (np.ndarray): Validation data to compute utility.
        utility_function (Callable[[np.ndarray, np.ndarray, float, float], float]): Function that computes utility.
        utility_threshold (float): Acceptable utility threshold.
        main_epsilons (list[float]): Epsilons for Brownian mechanism.
        utility_epsilon (float): Epsilon for utility check.
        utility_variance_split (float): Fraction of noise variance allocated to the threshold mechanism.
        alpha (float): RDP alpha parameter.
        main_sensitivity (float): Sensitivity of real_value.
        utility_sensitivity (float): Sensitivity of utility.

    Returns:
        tuple[list[np.ndarray], float]: List of released values, final epsilon.
    """
    threshold_scale, value_scale = calculate_mechanism_scales_for_svt_nonneg_utility(
        utility_epsilon, utility_variance_split, alpha, utility_sensitivity,
        utility_threshold
    )

    def perturb_threshold(threshold):
        # Gaussian noise on the threshold.
        return threshold + np.random.normal(0, threshold_scale)

    def perturb_utility(utility):
        # Gaussian noise on each utility value.
        return utility + np.random.normal(0, value_scale)

    return _run_until_private_utility_above_threshold_impl(
        real_value, validation_data, utility_function, utility_threshold,
        perturb_threshold, perturb_utility,
        main_epsilons, alpha, main_sensitivity
    )


def _run_until_private_utility_above_threshold_impl(
        real_value: np.ndarray, validation_data: np.ndarray, 
        utility_function: Callable[[np.ndarray, np.ndarray, float, float], float], 
        utility_threshold: float,
        utility_threshold_mechanism: Callable[[float], float], 
        utility_value_mechanism: Callable[[float], float],
        main_epsilons: list[float], alpha: float, main_sensitivity: float,
        ) -> tuple[list[np.ndarray], float]:
    """Shared release loop for the SVT-based stopping rules.

    The threshold is perturbed once up front. Then, for each epsilon in order,
    a value is released, its utility is computed and perturbed, and the loop
    stops at the first perturbed utility that is at least the perturbed
    threshold.

    Args:
        real_value (np.ndarray): Private value to release.
        validation_data (np.ndarray): Validation data to compute utility.
        utility_function (Callable[[np.ndarray, np.ndarray, float, float], float]): Function that computes
            utility from (released value, validation_data, release variance, release epsilon).
        utility_threshold (float): Acceptable utility threshold.
        utility_threshold_mechanism (Callable[[float], float]): Maps the threshold to its noisy version.
        utility_value_mechanism (Callable[[float], float]): Maps a utility value to its noisy version.
        main_epsilons (list[float]): Epsilons for Brownian mechanism.
        alpha (float): RDP alpha parameter.
        main_sensitivity (float): Sensitivity of real_value.

    Returns:
        tuple[list[np.ndarray], float]: List of released values, final epsilon
        (0.0 if main_epsilons is empty).
    """
    brownian = BrownianMechanism(real_value, alpha, main_sensitivity)
    noisy_threshold = utility_threshold_mechanism(utility_threshold)

    outputs = []
    # Stays at 0.0 when there is nothing to iterate over.
    spent_epsilon = 0.0
    for spent_epsilon in main_epsilons:
        latest = brownian.release_value(spent_epsilon)
        outputs.append(latest)

        variance = gaussian_mechanism_rdp_variance(spent_epsilon, alpha, main_sensitivity)
        score = utility_function(latest, validation_data, variance, spent_epsilon)
        if utility_value_mechanism(score) >= noisy_threshold:
            break

    return outputs, spent_epsilon


def calculate_mechanism_scale_no_svt(
        utility_epsilon: float, utility_sensitivity: float, max_release_count: int,
        alpha: float
        ) -> float:
    """Compute the Gaussian noise scale for a threshold check without SVT.

    utility_epsilon is divided evenly over ``max_release_count - 1`` checks,
    and the scale is the square root of the Gaussian-mechanism variance
    computed for that per-check epsilon.

    Args:
        utility_epsilon (float): Epsilon for threshold check.
        utility_sensitivity (float): Sensitivity of utility value.
        max_release_count (int): Maximum number of releases. Must be at least 2.
        alpha (float): RDP alpha parameter.

    Returns:
        float: Noise scale for Gaussian mechanism.

    Raises:
        ValueError: If max_release_count is less than 2.
    """
    if max_release_count < 2:
        raise ValueError(f"max_release_count must be at least 2. Got {max_release_count}")

    per_check_epsilon = utility_epsilon / (max_release_count - 1)
    per_check_variance = gaussian_mechanism_rdp_variance(per_check_epsilon, alpha, utility_sensitivity)
    return per_check_variance**0.5


def run_until_private_utility_above_threshold_no_svt(
        real_value: np.ndarray, validation_data: np.ndarray,
        utility_function: Callable[[np.ndarray, np.ndarray, float, float], float],
        utility_threshold: float, utility_epsilon: float, utility_sensitivity: float,
        main_epsilons: list[float], alpha: float, main_sensitivity: float,
        ) -> tuple[list[np.ndarray], float]:
    """Release with the Brownian mechanism until the utility measured on
    private validation data crosses a threshold.

    No SVT is used: each utility value is perturbed with Gaussian noise and
    compared against the exact threshold. The utility function is still
    evaluated for the last epsilon, but no check is made for it.

    utility_function is called as
    (released value, validation_data, release variance, release epsilon).

    Args:
        real_value (np.ndarray): Private value to release.
        validation_data (np.ndarray): Validation data to compute utility.
        utility_function (Callable[[np.ndarray, np.ndarray, float, float], float]): Function that computes utility.
        utility_threshold (float): Acceptable utility threshold.
        utility_epsilon (float): Epsilon for utility check.
        utility_sensitivity (float): Sensitivity of utility.
        main_epsilons (list[float]): Epsilons for Brownian mechanism.
        alpha (float): RDP alpha parameter.
        main_sensitivity (float): Sensitivity of real_value.

    Returns:
        tuple[list[np.ndarray], float]: List of released values, final epsilon.
    """
    brownian = BrownianMechanism(real_value, alpha, main_sensitivity)
    release_count = len(main_epsilons)
    check_scale = calculate_mechanism_scale_no_svt(utility_epsilon, utility_sensitivity, release_count, alpha)

    outputs = []
    spent_epsilon = 0.0
    final_index = release_count - 1

    for index, spent_epsilon in enumerate(main_epsilons):
        latest = brownian.release_value(spent_epsilon)
        outputs.append(latest)

        variance = gaussian_mechanism_rdp_variance(spent_epsilon, alpha, main_sensitivity)
        score = utility_function(latest, validation_data, variance, spent_epsilon)

        # The final release is returned regardless, so it gets no noisy check.
        if index >= final_index:
            continue
        if score + np.random.normal(0, check_scale) >= utility_threshold:
            break

    return outputs, spent_epsilon


def get_noise_scales_no_svt(
        utility_epsilon: float, final_epsilon: float, utility_sensitivity: float, 
        main_sensitivity: float, max_release_count: int, alpha: float
        ) -> tuple[float, float]:
    """Collect the noise scales used with the plain Gaussian utility check.

    Args:
        utility_epsilon (float): Epsilon for threshold check.
        final_epsilon (float): Final epsilon of Brownian mechanism.
        utility_sensitivity (float): Sensitivity of utility value.
        main_sensitivity (float): Sensitivity of real_value.
        max_release_count (int): Maximum number of releases. Must be at least 2.
        alpha (float): RDP alpha parameter.

    Returns:
        tuple[float, float]: Utility check noise scale, Brownian mechanism noise scale.
    """
    check_scale = calculate_mechanism_scale_no_svt(
        utility_epsilon, utility_sensitivity, max_release_count, alpha
    )

    # The helper yields a variance; the reported scale is its square root.
    main_variance = gaussian_mechanism_rdp_variance(final_epsilon, alpha, main_sensitivity)
    return check_scale, main_variance**0.5


def get_epsilons_with_noise_scales_no_svt(
        utility_sigma: float, main_sigma: float, utility_sensitivity: float,
        main_sensitivity: float, max_release_count: int, alpha: float
        ) -> tuple[float, float]:
    """Recover epsilons from the noise scales of the plain Gaussian check.

    The utility epsilon is the per-check Gaussian-mechanism epsilon multiplied
    by ``max_release_count - 1``.

    Args:
        utility_sigma (float): Noise scale for utility mechanism.
        main_sigma (float): Noise scale for Brownian mechanism.
        utility_sensitivity (float): Sensitivity of utility value.
        main_sensitivity (float): Sensitivity of released value.
        max_release_count (int): Maximum number of releases. Must be at least 2.
        alpha (float): RDP alpha parameter.

    Returns:
        tuple[float, float]: Epsilon for utility mechanism, epsilon for Brownian mechanism.
    """
    per_check_epsilon = gaussian_mechanism_rdp_epsilon(utility_sigma, alpha, utility_sensitivity)
    total_check_epsilon = per_check_epsilon * (max_release_count - 1)

    brownian_epsilon = gaussian_mechanism_rdp_epsilon(main_sigma, alpha, main_sensitivity)
    return total_check_epsilon, brownian_epsilon