from scipy.stats import norm
import numpy as np

# Calculate h and r recursively (no abstentions)
def rh(inverse_blow_up_function, alpha, beta, j, m, k=2):
    """Run the backward (r, h) recursion used by the no-abstention audit.

    Seeds ``r[j] = alpha`` and ``h[j] = beta``, then fills indices
    ``j - 1`` down to ``0`` with::

        h[i] = max(h[i + 1], (k - 1) * inverse_blow_up_function(r[i + 1]))
        r[i] = r[i + 1] + i / (m - i) * (h[i] - h[i + 1])

    Args:
        inverse_blow_up_function: Callable applied to each ``r[i + 1]``.
        alpha: Terminal value stored at ``r[j]``.
        beta: Terminal value stored at ``h[j]``.
        j: Last index of the recursion; both lists have ``j + 1`` entries.
        m: Count used in the ``i / (m - i)`` weight.
        k: The inverse blow-up value is scaled by ``k - 1``.

    Returns:
        Tuple ``(r, h)`` of lists indexed ``0..j``.
    """
    r_vals = [0] * (j + 1)
    h_vals = [0] * (j + 1)
    h_vals[j] = beta
    r_vals[j] = alpha
    for idx in reversed(range(j)):
        next_h = h_vals[idx + 1]
        next_r = r_vals[idx + 1]
        # h never decreases as we walk back towards index 0.
        h_vals[idx] = max(next_h, (k - 1) * inverse_blow_up_function(next_r))
        weight = idx / (m - idx)
        r_vals[idx] = next_r + weight * (h_vals[idx] - next_h)
    return r_vals, h_vals

# Audit function without abstention
def audit_rh(inverse_blow_up_function, m, c, threshold=0.05, k=2):
    """Run the no-abstention audit and report whether it passes.

    ``threshold`` is split into ``alpha = threshold * c / m`` and
    ``beta = threshold * (m - c) / m``. These seed ``rh`` at index ``c``.
    The audit passes unless ``r[0] + h[0]`` exceeds 1.0.

    Args:
        inverse_blow_up_function: Callable forwarded to ``rh``.
        m: Total count forwarded to ``rh``.
        c: Index at which the recursion starts; presumably the number of
            correct guesses (TODO confirm with caller).
        threshold: Budget split into ``alpha`` and ``beta``.
        k: Forwarded to ``rh``.

    Returns:
        ``True`` if ``r[0] + h[0]`` is not greater than 1.0, else ``False``.
    """
    alpha = threshold * c / m
    beta = threshold * (m - c) / m
    r_list, h_list = rh(inverse_blow_up_function, alpha, beta, c, m, k)
    total = r_list[0] + h_list[0]
    # Written as "not greater than" (rather than "<=") so a NaN total passes
    # exactly as it did in the original if/else.
    return not (total > 1.0)

# Calculate h and r recursively (with abstentions)
def rh_with_cap(inverse_blow_up_function, alpha, beta, j, m, c_cap, k=2):
    """Run the backward (r, h) recursion used by the abstention audit.

    Identical in shape to the no-abstention recursion, except that the
    ``r`` update weight is ``i / (c_cap - i)`` instead of ``i / (m - i)``::

        h[i] = max(h[i + 1], (k - 1) * inverse_blow_up_function(r[i + 1]))
        r[i] = r[i + 1] + i / (c_cap - i) * (h[i] - h[i + 1])

    Args:
        inverse_blow_up_function: Callable applied to each ``r[i + 1]``.
        alpha: Terminal value stored at ``r[j]``.
        beta: Terminal value stored at ``h[j]``.
        j: Last index of the recursion; both lists have ``j + 1`` entries.
        m: Accepted for signature compatibility; not used by the recursion.
        c_cap: Count used in the ``i / (c_cap - i)`` weight.
        k: The inverse blow-up value is scaled by ``k - 1``.

    Returns:
        Tuple ``(r, h)`` of lists indexed ``0..j``.
    """
    r_vals = [0] * (j + 1)
    h_vals = [0] * (j + 1)
    h_vals[j] = beta
    r_vals[j] = alpha
    for idx in reversed(range(j)):
        next_h = h_vals[idx + 1]
        next_r = r_vals[idx + 1]
        h_vals[idx] = max(next_h, (k - 1) * inverse_blow_up_function(next_r))
        weight = idx / (c_cap - idx)
        r_vals[idx] = next_r + weight * (h_vals[idx] - next_h)
    return r_vals, h_vals

# Audit function with abstentions
def audit_rh_with_cap(inverse_blow_up_function, m, c, c_cap, threshold=0.05, k=2):
    """Run the abstention audit and report whether it passes.

    The threshold is first rescaled by ``c_cap / m``. It is then split into
    ``alpha`` (share ``c / c_cap``) and ``beta`` (share
    ``(c_cap - c) / c_cap``), which seed ``rh_with_cap`` at index ``c``.
    The audit passes unless ``r[0] + h[0]`` exceeds ``c_cap / m``.

    Args:
        inverse_blow_up_function: Callable forwarded to ``rh_with_cap``.
        m: Total count; scales both the threshold and the pass bound.
        c: Index at which the recursion starts; presumably the number of
            correct guesses (TODO confirm with caller).
        c_cap: Presumably the number of guesses actually made, i.e. not
            abstained (TODO confirm with caller).
        threshold: Budget before rescaling by ``c_cap / m``.
        k: Forwarded to ``rh_with_cap``.

    Returns:
        ``True`` if ``r[0] + h[0]`` is not greater than ``c_cap / m``.
    """
    scaled = threshold * c_cap / m
    alpha = scaled * c / c_cap
    beta = scaled * (c_cap - c) / c_cap
    r_list, h_list = rh_with_cap(inverse_blow_up_function, alpha, beta, c, m, c_cap, k)
    total = r_list[0] + h_list[0]
    # "not greater than" keeps the original NaN-passes behaviour.
    return not (total > c_cap / m)
        
# Calculate the blow-up function for Gaussian noise
def gaussianDP_blow_up_function(noise):
    """Build the Gaussian blow-up function for a given noise level.

    The returned callable maps ``x`` to ``Phi(Phi^{-1}(x) + 1 / noise)``,
    where ``Phi`` is the standard normal CDF. In other words, it shifts the
    normal quantile of ``x`` up by ``1 / noise``.
    """
    def blow_up_function(x):
        # Shift the standard-normal quantile of x up by 1 / noise, then map back.
        return norm.cdf(norm.ppf(x) + 1 / noise)
    return blow_up_function

# Calculate the inverse blow-up function for Gaussian noise
def gaussianDP_blow_up_inverse(noise):
    """Build the inverse of the Gaussian blow-up function for a noise level.

    The returned callable maps ``x`` to ``Phi(Phi^{-1}(x) - 1 / noise)``,
    where ``Phi`` is the standard normal CDF. In other words, it shifts the
    normal quantile of ``x`` down by ``1 / noise``.
    """
    def blow_up_inverse_function(x):
        # Shift the standard-normal quantile of x down by 1 / noise, then map back.
        return norm.cdf(norm.ppf(x) - 1 / noise)
    return blow_up_inverse_function

# Define a function to calculate delta for Gaussian noise
def calculate_delta_gaussian(noise, epsilon):
    """Return delta for the Gaussian mechanism at the given noise and epsilon.

    Evaluates::

        Phi(-epsilon * noise + 1 / (2 * noise))
            - exp(epsilon) * Phi(-epsilon * noise - 1 / (2 * noise))

    With ``mu = 1 / noise`` this is the usual mu-GDP to (epsilon, delta)
    conversion.
    """
    half_inv_noise = 1 / (2 * noise)
    scaled_eps = -epsilon * noise
    upper_term = norm.cdf(scaled_eps + half_inv_noise)
    lower_term = np.exp(epsilon) * norm.cdf(scaled_eps - half_inv_noise)
    return upper_term - lower_term

# Define a function to calculate epsilon for Gaussian noise
def calculate_epsilon_gaussian(noise, delta):
    """Bisect for the epsilon at which the Gaussian delta drops to ``delta``.

    Searches ``[0, 100]`` until the bracket is no wider than 0.001. At each
    step the midpoint moves the lower end up when
    ``calculate_delta_gaussian(noise, mid) > delta``, and the upper end down
    otherwise. This assumes delta is non-increasing in epsilon.

    Returns:
        The upper end of the final bracket, i.e. the conservative epsilon.
    """
    lo, hi = 0, 100
    while hi - lo > 0.001:
        mid = (hi + lo) / 2
        too_small = calculate_delta_gaussian(noise, mid) > delta
        if too_small:
            lo = mid
        else:
            hi = mid
    return hi

# Get the empirical epsilon value (linear scan over candidate noises)
def get_gaussian_emp_eps_ours(candidate_noises, inverse_blow_up_functions, m, c, c_cap, threshold, delta, k=2):
    """Return the empirical epsilon found by a linear scan over candidates.

    Each inverse blow-up function is audited in order with
    ``audit_rh_with_cap``. The first one that fails selects the matching
    entry of ``candidate_noises``, which ``calculate_epsilon_gaussian``
    converts to epsilon at the given ``delta``.

    Args:
        candidate_noises: Noise levels, aligned index-by-index with
            ``inverse_blow_up_functions``.
        inverse_blow_up_functions: Functions to audit, in scan order.
        m, c, c_cap, threshold, k: Forwarded to ``audit_rh_with_cap``.
        delta: Delta used when converting the selected noise to epsilon.

    Returns:
        Epsilon for the first failing candidate, or ``-1`` if every candidate
        passes. An empty candidate list also returns ``-1``, matching
        ``get_gaussian_emp_eps_bs``; previously it raised IndexError.
    """
    for index, inverse_blow_up_function in enumerate(inverse_blow_up_functions):
        if not audit_rh_with_cap(inverse_blow_up_function, m, c, c_cap, threshold=threshold, k=k):
            # First failing candidate: its noise level defines the empirical epsilon.
            return calculate_epsilon_gaussian(candidate_noises[index], delta=delta)
    # Every candidate passed the audit (or there were none): no epsilon found.
    return -1

# Get the empirical epsilon value with binary search
def get_gaussian_emp_eps_bs(candidate_noises, inverse_blow_up_functions, m, c, c_cap, threshold, delta, k=2):
    """Return the empirical epsilon found by binary search over candidates.

    This is a bisection over the candidate indices that looks for the first
    index whose ``audit_rh_with_cap`` call fails. It is only equivalent to a
    linear scan if audits pass for a prefix of the list and fail afterwards.
    The ordering is assumed here, not checked.

    Args:
        candidate_noises: Noise levels, aligned with
            ``inverse_blow_up_functions``.
        inverse_blow_up_functions: Functions to audit.
        m, c, c_cap, threshold, k: Forwarded to ``audit_rh_with_cap``.
        delta: Delta used when converting the selected noise to epsilon.

    Returns:
        Epsilon for the first failing index, or ``-1`` if the search ends
        past the last index (including for an empty list).
    """
    lo = 0
    hi = len(inverse_blow_up_functions) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        passed = audit_rh_with_cap(inverse_blow_up_functions[mid], m, c, c_cap, threshold=threshold, k=k)
        if passed:
            # The first failure must lie strictly to the right of mid.
            lo = mid + 1
        else:
            # mid fails; keep looking left for an earlier failure.
            hi = mid - 1

    if lo >= len(inverse_blow_up_functions):
        return -1
    return calculate_epsilon_gaussian(candidate_noises[lo], delta=delta)

MAX_NOISE = 1000  # upper end of the noise range searched by get_target_noise


def get_target_noise(target_eps, delta=1e-5):
    """Bisect for a Gaussian noise level whose epsilon is just below ``target_eps``.

    Searches noise in ``[0.1, MAX_NOISE]`` for 100 halving steps, evaluating
    ``calculate_epsilon_gaussian(noise, delta=delta)`` at each midpoint.
    When the midpoint's epsilon is below ``target_eps``, the high-noise end
    moves down to it; otherwise the low-noise end moves up. This assumes
    epsilon decreases as noise grows.

    Args:
        target_eps: Desired epsilon.
        delta: Delta at which epsilon is evaluated. Defaults to ``1e-5``,
            the value previously hard-coded, so existing callers are unchanged.

    Returns:
        The high-noise end of the final bracket. This is ``MAX_NOISE`` itself
        if no midpoint ever had an epsilon below ``target_eps``.
    """
    high_noise = MAX_NOISE  # small-epsilon side of the bracket
    low_noise = 0.1         # large-epsilon side of the bracket
    for _ in range(100):
        mid_noise = (high_noise + low_noise) / 2
        eps = calculate_epsilon_gaussian(mid_noise, delta=delta)
        if eps < target_eps:
            high_noise = mid_noise
        else:
            low_noise = mid_noise
    return high_noise