# -*- coding: utf-8 -*-
"""
Created on Sun Sep 25 17:10:49 2022
"""

from eps_delta_edgeworth import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

from prv_accountant import PRVAccountant
from prv_accountant.privacy_random_variables import PoissonSubsampledGaussianMechanism, GaussianMechanism, LaplaceMechanism

# Parameters for mechanism 1. `sigma_1` is presumably the Gaussian noise
# multiplier (the FFT accountant further down uses the same 0.8 — confirm);
# `mu_1` is its reciprocal and serves as the mean shift of the mixture
# component in dens_func_Y / log_likelihood_ratio_func.
sigma_1 = 0.8
mu_1 = 1.0 / sigma_1
def log_likelihood_ratio_func(x):
    """Return log(dens_func_Y(x) / dens_func_X(x)) for mechanism 1.

    With the standard normal as the reference density and the mixture
    ``(1 - p_1) * N(0, 1) + p_1 * N(mu_1, 1)`` as the alternative, the ratio
    simplifies to ``log(1 - p_1 + p_1 * exp(mu_1 * x - mu_1 ** 2 / 2))``.

    The module-level globals ``p_1`` and ``mu_1`` are looked up on every
    call, so ``p_1`` must have been assigned before the function is evaluated.

    Parameters
    ----------
    x : float
        Scalar evaluation point (the ``x > 0`` test is a scalar comparison).

    Returns
    -------
    float
        The log-likelihood ratio at ``x``.
    """
    p = p_1
    mu = mu_1
    if x > 0:
        # Same quantity with exp(mu * x) pulled out of the logarithm, so the
        # exponential inside stays <= 1 for positive x and cannot overflow.
        return mu * x + np.log((1 - p) * np.exp(-mu * x) + p * np.exp(- mu * mu / 2))
    return np.log(1 - p + p * np.exp(mu * x - mu * mu / 2))
def dens_func_X(x):
    """Evaluate the standard normal density N(0, 1) at ``x``."""
    standard_normal = scipy.stats.norm
    return standard_normal.pdf(x)
def dens_func_Y(x):
    """Evaluate the mechanism-1 mixture density at ``x``.

    The density is ``(1 - p_1) * N(0, 1) + p_1 * N(mu_1, 1)``, where the
    mixing weight ``p_1`` and the shift ``mu_1`` are module-level globals read
    at call time (``p_1`` must be assigned before the first evaluation).

    Parameters
    ----------
    x : float or array-like
        Point(s) passed straight to ``scipy.stats.norm.pdf``.

    Returns
    -------
    The mixture density at ``x``.
    """
    p = p_1
    mu = mu_1
    return (1 - p) * scipy.stats.norm.pdf(x) + p * scipy.stats.norm.pdf(x, loc = mu)


# Wrap the reference density and the mixture density of mechanism 1. Both
# objects share the same log-likelihood-ratio function and the same third
# constructor argument (2).
gm_x_1, gm_y_1 = (
    Distribution(density, log_likelihood_ratio_func, 2)
    for density in (dens_func_X, dens_func_Y)
)

# Mechanism 2: same construction as mechanism 1, with its own parameters.
# Parameters for mechanism 2. `sigma_2` is presumably the Gaussian noise
# multiplier (the FFT accountant further down uses the same 0.8 — confirm);
# `mu_2` is its reciprocal and serves as the mean shift of the mixture
# component in dens_func_Y_2 / log_likelihood_ratio_func_2.
sigma_2 = 0.8
mu_2 = 1.0 / sigma_2
def log_likelihood_ratio_func_2(x):
    """Return log(dens_func_Y_2(x) / dens_func_X_2(x)) for mechanism 2.

    With the standard normal as the reference density and the mixture
    ``(1 - p_2) * N(0, 1) + p_2 * N(mu_2, 1)`` as the alternative, the ratio
    simplifies to ``log(1 - p_2 + p_2 * exp(mu_2 * x - mu_2 ** 2 / 2))``.

    The module-level globals ``p_2`` and ``mu_2`` are looked up on every
    call, so ``p_2`` must have been assigned before the function is evaluated.

    Parameters
    ----------
    x : float
        Scalar evaluation point (the ``x > 0`` test is a scalar comparison).

    Returns
    -------
    float
        The log-likelihood ratio at ``x``.
    """
    p = p_2
    mu = mu_2
    if x > 0:
        # Same quantity with exp(mu * x) pulled out of the logarithm, so the
        # exponential inside stays <= 1 for positive x and cannot overflow.
        return mu * x + np.log((1 - p) * np.exp(-mu * x) + p * np.exp(- mu * mu / 2))
    return np.log(1 - p + p * np.exp(mu * x - mu * mu / 2))
def dens_func_X_2(x):
    """Evaluate the standard normal density N(0, 1) at ``x`` (mechanism 2)."""
    standard_normal = scipy.stats.norm
    return standard_normal.pdf(x)
def dens_func_Y_2(x):
    """Evaluate the mechanism-2 mixture density at ``x``.

    The density is ``(1 - p_2) * N(0, 1) + p_2 * N(mu_2, 1)``, where the
    mixing weight ``p_2`` and the shift ``mu_2`` are module-level globals read
    at call time (``p_2`` must be assigned before the first evaluation).

    Parameters
    ----------
    x : float or array-like
        Point(s) passed straight to ``scipy.stats.norm.pdf``.

    Returns
    -------
    The mixture density at ``x``.
    """
    p = p_2
    mu = mu_2
    return (1 - p) * scipy.stats.norm.pdf(x) + p * scipy.stats.norm.pdf(x, loc = mu)


# Wrap the reference density and the mixture density of mechanism 2. Both
# objects share the same log-likelihood-ratio function and the same third
# constructor argument (2).
gm_x_2, gm_y_2 = (
    Distribution(density, log_likelihood_ratio_func_2, 2)
    for density in (dens_func_X_2, dens_func_Y_2)
)



# Number of self-compositions of mechanism 1 at each evaluation point.
n1_list = [
    50_000, 100_000, 166_666, 230_000, 235_000,
    240_000, 250_000, 277_777, 281_690, 285_714,
    294_117, 312_500, 333_333, 400_000, 500_000,
]
# Number of self-compositions of mechanism 2, paired element-wise with n1_list.
n2_list = [
    500_000, 1_000_000, 1_666_666, 2_300_000, 2_350_000,
    2_400_000, 2_500_000, 2_777_777, 2_816_901, 2_857_142,
    2_941_176, 3_125_000, 3_333_333, 4_000_000, 5_000_000,
]
# Target delta at which epsilon is computed.
delta = 1e-1

## FFT:
# Numerators of the per-mechanism sampling probabilities p / sqrt(n).
p1 = 0.35
p2 = 0.02

# Per (n1, n2) pair: FFT-based (PRV accountant) lower bound, upper bound and
# point estimate of epsilon at the target delta.
lowers, uppers, ests = [], [], []
for n1, n2 in zip(n1_list, n2_list):
    # Use the shared sigma_*/delta settings rather than repeating the
    # literals, so this run stays comparable with the Edgeworth computation.
    prv_a = PoissonSubsampledGaussianMechanism(
        noise_multiplier=sigma_1, sampling_probability=p1 / n1 ** 0.5
    )
    prv_b = PoissonSubsampledGaussianMechanism(
        noise_multiplier=sigma_2, sampling_probability=p2 / n2 ** 0.5
    )
    accountant = PRVAccountant(
        prvs=[prv_a, prv_b],
        max_self_compositions=[500000, 5000000],
        eps_error=0.1,
        delta_error=1e-2,
    )
    eps_low, eps_est, eps_up = accountant.compute_epsilon(
        delta=delta, num_self_compositions=[n1, n2]
    )
    # Keep each value in its matching list: the estimate and the upper bound
    # were previously swapped, which mislabelled the plotted "FFT_UPP" curve.
    lowers.append(eps_low)
    uppers.append(eps_up)
    ests.append(eps_est)


##EEAI
# Edgeworth-based computation of epsilon for the same heterogeneous composition.
distribution_list_x = [gm_x_1, gm_x_2]
distribution_list_y = [gm_y_1, gm_y_2]

lower, upper, est = [], [], []
for n1, n2 in zip(n1_list, n2_list):
    # The density / log-likelihood-ratio helpers read these globals at call
    # time, so they must be refreshed before the mechanism is built and used.
    p_1, p_2 = 0.35 / n1 ** 0.5, 0.02 / n2 ** 0.5
    mechanism = HeterogeneousComposition(distribution_list_x, distribution_list_y, order=2)
    eps_bounds = mechanism.approx_eps_from_delta_edgeworth(delta, numbers_list=[n1, n2])
    eps_estimate, eps_lower, eps_upper = eps_bounds
    print(eps_estimate, eps_lower, eps_upper)
    est.append(eps_estimate)
    lower.append(eps_lower)
    upper.append(eps_upper)

# Total number of compositions (n1 + n2) at each evaluation point; this is
# the x-axis of the plot. `number_lst` is a second name for the same list.
n_list = number_lst = list(map(sum, zip(n1_list, n2_list)))



figure(figsize=(4, 4))

# Only the first five FFT points are drawn; presumably the FFT bounds are not
# usable beyond that range (the output file name refers to a numerical issue
# with FFT) — confirm.
plt.plot(number_lst[:5], lowers[:5], label="FFT_LOW", linestyle="dashed", color="black")
plt.plot(number_lst[:5], uppers[:5], label="FFT_UPP", linestyle="dashed", color="black")
# Edgeworth estimate and bounds over the full range of composition counts.
plt.plot(number_lst, est, label="EW_EST")
plt.plot(number_lst, upper, label="EW_UPP")
plt.plot(number_lst, lower, label="EW_LOW")

plt.legend(fontsize=10)
plt.ylabel(r"$\epsilon$", fontsize=15, rotation=90)
# Raw string: "\m" is not a valid Python escape sequence, so the non-raw
# literal triggers an invalid-escape warning on recent Python versions.
plt.xlabel(r"m ($\mathregular{10^6}$)", fontsize=15)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.savefig("Heterogeneous_numerical_issue_with_FFT_zoom_in.pdf", format="pdf", bbox_inches="tight")

