#Code for noise reduction mechanisms
import math
import numpy as np
import numpy.random as rd
from scipy.optimize import newton_krylov, minimize, root_scalar
import matplotlib.pyplot as plt 

'''
Returns the first index in a list whose value is at or above a target threshold
(returns len(values) when no value reaches the threshold)
values : list of values
thresh : target threshold
'''
def get_bars_basic(values, thresh):
    """Return the index of the first entry of ``values`` that is >= ``thresh``.

    values : list of values to scan, in order
    thresh : target threshold (the comparison is inclusive)

    Returns ``len(values)`` when no entry reaches the threshold.
    """
    # Lazily enumerate the qualifying positions; only the first one is needed.
    hits = (pos for pos, val in enumerate(values) if val >= thresh)
    return next(hits, len(values))

'''
Runs Laplace Noise Reduction to mask a hidden parameter
eps : list of privacy parameters
l1_sens : sensitivity of the function producing the hidden parameter
beta : hidden parameter
'''
def LNR(eps, l1_sens, beta):
    """Run Laplace Noise Reduction and return one noisy copy of ``beta`` per
    entry of ``eps``.

    eps : sequence of privacy parameters; the entry at index k is released at
          level eps[k]
    l1_sens : sensitivity used to set the Laplace scale (l1_sens / eps[k])
    beta : hidden parameter (must support ``len``)

    Returns a list of ``len(eps)`` entries. The last entry is drawn first,
    centred on ``beta``; the remaining entries are generated walking backwards
    from it.

    NOTE(review): the keep-probability (eps[k] / eps[k + 1]) ** 2 is passed to
    ``rd.binomial`` as a probability, so eps is presumably sorted in increasing
    order -- confirm with callers.
    """
    dim = len(beta)
    count = len(eps)
    # Placeholders only: every slot is reassigned below, so it does not matter
    # that they all reference the same zero array.
    releases = [np.zeros(dim)] * count
    releases[-1] = rd.laplace(beta, l1_sens / eps[-1])
    for k in range(count - 2, -1, -1):
        keep_prob = (eps[k] / eps[k + 1]) ** 2
        if rd.binomial(1, keep_prob):
            # Re-use the next release unchanged.
            releases[k] = releases[k + 1]
        else:
            # Otherwise add fresh Laplace noise on top of the next release.
            releases[k] = rd.laplace(releases[k + 1], l1_sens / eps[k])
    return releases

'''
Runs Brownian Mechanism to mask a hidden parameter
eps : list of target privacy parameters
beta : parameter to hide
f : function mapping privacy parameters to noise levels
'''        
def BM(eps, beta, f):
    """Run the Brownian Mechanism and return one noisy copy of ``beta`` per
    entry of ``eps``.

    eps : sequence of target privacy parameters
    beta : parameter to hide (must support ``len``)
    f : function mapping a privacy parameter to a noise level (variance)

    Returns a list of ``len(eps)`` entries. The last entry is ``beta`` plus
    Gaussian noise with standard deviation sqrt(f(eps[-1])); each earlier entry
    is the following one plus independent Gaussian noise with variance
    f(eps[k]) - f(eps[k + 1]).

    ``math.sqrt`` raises ValueError on a negative argument, so ``f`` must not
    increase from eps[k] to eps[k + 1].
    """
    dim = len(beta)
    count = len(eps)
    # Placeholders only: every slot is reassigned below.
    path = [np.zeros(dim)] * count
    path[-1] = rd.normal(beta, math.sqrt(f(eps[-1])))
    for k in range(count - 2, -1, -1):
        # Original note: math.sqrt(l2_sens) was removed from the variance here.
        step_std = math.sqrt(f(eps[k]) - f(eps[k + 1]))
        path[k] = path[k + 1] + rd.normal(np.zeros(dim), step_std)
    return path




'''
Runs AboveThreshold to privately assess the first time a parameter obtains utility above a target
threshold
thresh : target threshold
sens : sensitivity of the utility function
epsilon : noise level to add via AboveThreshold
values : list of values to check if are above threshold
'''
def AboveThreshold(thresh, sens, epsilon, values):
    """Return the first index whose noisy value reaches a noisy threshold.

    thresh : target threshold
    sens : sensitivity of the utility function
    epsilon : privacy parameter controlling the amount of Laplace noise
    values : list of values to compare against the threshold

    The threshold is perturbed once with Laplace scale 2*sens/epsilon; each
    value is perturbed with Laplace scale 4*sens/epsilon just before it is
    compared. Returns ``len(values) - 1`` when no noisy value reaches the
    noisy threshold (which is -1 for an empty ``values``).
    """
    noisy_cutoff = rd.laplace(thresh, 2.0 * sens / epsilon)
    for idx, val in enumerate(values):
        noisy_val = rd.laplace(val, 4.0 * sens / epsilon)
        if noisy_val >= noisy_cutoff:
            return idx
    return len(values) - 1

'''
Runs ReducedAboveThreshold to privately assess the first time a parameter obtains utility above a target
threshold
thresh : target threshold
sens : sensitivity of the utility function
eps : list of privacy parameters to use in each round for RAT
values : list of values to check if are above threshold
'''  
def ReducedAboveThreshold(thresh, sens, eps, values):
    """Return the first index whose noisy value reaches its noisy threshold,
    using a separate privacy parameter for every round.

    thresh : target threshold
    sens : sensitivity of the utility function
    eps : list (or array) of privacy parameters, one per round; it is indexed
          with the same index as ``values``, so it needs at least
          ``len(values)`` entries
    values : list of values to compare against the threshold

    The privacy parameters are doubled, a sequence of noisy thresholds is
    obtained from ``LNR`` on the doubled parameters, and value i is perturbed
    with Laplace scale 4*sens/(2*eps[i]) before being compared with noisy
    threshold i. Returns ``len(values) - 1`` when no round succeeds.
    """
    # Convert before scaling: `2.0 * eps` raises TypeError for a plain Python
    # list, which is the documented input type.
    eps = 2.0 * np.asarray(eps, dtype=float)
    # LNR is handed a length-1 array, so each threshold is presumably a
    # length-1 array as well; the comparison below relies on its truth value.
    threshes = LNR(eps, sens, np.array([thresh]))
    for i in range(len(values)):
        if rd.laplace(values[i], 4.0*sens/eps[i]) >= threshes[i]:
            return i
    return len(values) - 1
    

'''
Optimizes the parameter rho for the mixture privacy boundary for a target privacy level
eps : privacy parameter to optimize boundary tightness at
t : function (representing mixture boundary) taking privacy parameters to corresponding times
'''
def optimize_rho(eps, t):
    """Pick the rho that minimises ``t(eps, rho)`` for a fixed privacy level.

    eps : privacy parameter to optimise boundary tightness at; also used as the
          optimiser's starting point
    t : callable ``t(eps, rho)`` (representing the mixture boundary) returning
        the quantity to minimise; ``rho`` is supplied by ``minimize``

    Returns the first component of the minimiser found by
    ``scipy.optimize.minimize`` with rho constrained to [0.1, 100].
    """
    def objective(rho):
        return t(eps, rho)

    result = minimize(objective, eps, bounds=[(0.1, 100)])
    return result.x[0]

'''
Optimizes the parameter a for the linear privacy boundary for a target privacy level
eps : privacy parameter to optimize boundary tightness at
t : function (representing linear boundary) taking privacy parameters to corresponding times
sens : l2 sensitivity -- gives upper bound on range to optimize over
'''
def optimize_a(eps, t, sens):
    """Pick the a that minimises ``t(eps, a)`` for a fixed privacy level.

    eps : privacy parameter to optimise boundary tightness at
    t : callable ``t(eps, a)`` (representing the linear boundary) returning the
        quantity to minimise; ``a`` is supplied by ``minimize``
    sens : l2 sensitivity -- eps/sens is the upper bound of the search range

    Returns the first component of the minimiser found by
    ``scipy.optimize.minimize`` with a constrained to [0.0001, eps/sens],
    starting from the midpoint 0.5*eps/sens.
    """
    def objective(a):
        return t(eps, a)

    start = 0.5 * eps / sens
    search_range = [(0.0001, eps / sens)]
    result = minimize(objective, start, bounds=search_range)
    return result.x[0]


'''
Map target privacy levels to noise levels via mixture privacy boundary
eps : input privacy level
rho : parameter in mixture boundary
delta : probability of failure
l2_sens : l2 sensitivity.
'''
def mixture_e_to_t(eps, rho, delta, l2_sens):
    """Map a target privacy level to a noise level via the mixture boundary.

    eps : input privacy level
    rho : parameter in mixture boundary
    delta : probability of failure
    l2_sens : l2 sensitivity

    Returns the x in [1e-6, 4e5] at which the mixture boundary expression
    equals ``eps``, located by bisection. ``root_scalar`` requires the residual
    to change sign over that bracket and raises ValueError otherwise.
    """
    def residual(x):
        shifted = x + rho
        log_term = math.log(float(math.sqrt(shifted / rho) * (1 / delta)))
        root_part = (l2_sens / x) * math.sqrt(2 * shifted * log_term)
        quad_part = (l2_sens * l2_sens) / (2 * x)
        return root_part + quad_part - eps

    # Alternative left disabled in the original: newton_krylov(f, eps)
    solution = root_scalar(residual, method="bisect", bracket=[0.000001, 400000.00])
    return solution.root

'''
Mixture privacy boundary evaluated at time (variance) t.
t : time to evaluate the boundary at.
rho : parameter in mixture boundary
delta : probability of failure
l2_sens : l2 sensitivity.

'''

def mixture(t, rho, delta, l2_sens):
    """Evaluate the mixture privacy boundary at time (variance) ``t``.

    t : time to evaluate the boundary at
    rho : parameter in mixture boundary
    delta : probability of failure
    l2_sens : l2 sensitivity

    Returns
    (l2_sens/t) * sqrt(2*(t+rho)*log(sqrt((t+rho)/rho)/delta)) + l2_sens**2/(2*t).
    Uses ``math`` functions, so the arguments are expected to be scalars.
    """
    shifted = t + rho
    log_term = math.log(float(math.sqrt(shifted / rho) * (1 / delta)))
    root_part = (l2_sens / t) * math.sqrt(2 * shifted * log_term)
    quad_part = (l2_sens * l2_sens) / (2 * t)
    return root_part + quad_part

'''
Linear privacy boundary evaluated at time (variance) t.
t : time to evaluate the boundary at.
a : parameter in linear boundary
delta : probability of failure
l2_sens : l2 sensitivity.

'''
def linear(t, a, delta, l2_sens):
    """Evaluate the linear privacy boundary at time (variance) ``t``.

    t : time to evaluate the boundary at
    a : parameter in linear boundary
    delta : probability of failure
    l2_sens : l2 sensitivity

    Returns (l2_sens/t) * (l2_sens/2 + log(1/delta)/(2*a)) + l2_sens*a.
    """
    log_inv_delta = math.log(1/delta)
    bracket = l2_sens/2 + 1/(2*a)*log_inv_delta
    time_dependent = (l2_sens/t)*bracket
    return time_dependent + l2_sens*a


'''
Map target privacy levels to noise levels via linear privacy boundary
eps : input privacy level
a : parameter in linear boundary
delta : probability of failure
l2_sens : l2 sensitivity.
'''

def linear_e_to_t(eps, a, delta, l2_sens):
    """Map a target privacy level to a noise level via the linear boundary.

    eps : input privacy level
    a : parameter in linear boundary
    delta : probability of failure
    l2_sens : l2 sensitivity

    Returns l2_sens * (l2_sens/2 + log(1/delta)/(2*a)) / (eps - l2_sens*a),
    i.e. the t at which the linear boundary expression equals ``eps``. The
    denominator is zero when eps == l2_sens*a and negative when eps is smaller.
    """
    numerator = l2_sens*(l2_sens/2 + 1/(2*a)*math.log(1/delta))
    denominator = eps - l2_sens * a
    return numerator/denominator

'''
def gen_ploss_loss(beta, f, X, y, loss, l2_sens, l1_sens, window=(0.1, 5.0), steps=100, target=0.9, show_target=False, method="NONE", epsilon=1.0):
    spread = window[1] - window[0]
    eps = np.array([window[0] + n*(spread/steps) for n in range(steps)])
    betas_1 = LNR(eps, l1_sens, beta)
    betas_2 = BM(eps, l2_sens, beta, f)
    acc_1 = [loss(X_test, y_test, betas_1[i]) for i in range(steps)]
    acc_2 = [loss(X_test, y_test, betas_2[i]) for i in range(steps)]
    if show_target:
        if method == "NONE":
            i_1, i_2 = _get_bars_basic(acc_1, target), _get_bars_basic(acc_2, target)
        elif method == "AT":
            i_1, i_2 = AboveThreshold(target, 1.0/len(X), epsilon, acc_1), AboveThreshold(target, 1.0/len(X), epsilon, acc_2)
            eps = eps + epsilon
        else:
            i_1, i_2 = ReducedAboveThreshold(target, 1.0/len(X), eps, acc_1), ReducedAboveThreshold(target, 1.0/len(X), eps, acc_2)
            eps = eps + eps
    return eps, ((acc_1, i_1), (acc_2, i_2))

'''
