#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 17 14:23:18 2024

@author: jay
"""

# Comparison between our method and that of Wang et al.
import math
import numpy as np
from scipy.special import comb as comb
import matplotlib.pyplot as plt

from typing import List, Tuple, Union
import warnings

# Subsampling ratio: 120 out of 50000 (equals batch_size / n_train used in the
# experiment section of this script).
q = 120.0 / 50000.0

# Noise scale sigma; the bounds below use it through 4 / sigma**2 and 2 / sigma.
# NOTE(review): presumably the Gaussian noise multiplier — confirm.
sigma = 6.0

# Ratio r / sigma~ with r = 2, squared inside the exponentials of M_exact and
# A_bound. NOTE(review): r = 2 presumably encodes the sensitivity — confirm.
r_over_sigma_tilde = 2.0 / sigma

def M_exact(k):
    """Evaluate the alternating binomial sum M_k used by the FS-RDP bound.

    M_k = (-1)^(k-1) (k-1)
          + sum_{ell=2}^{k} (-1)^(k-ell) C(k, ell) exp(ell (ell-1) r^2 / 2),

    where r = r_over_sigma_tilde (module-level constant).

    Args:
        k: Integer order of the term.

    Returns:
        The value of M_k (a NumPy float once any exponential term is added).
    """
    total = (-1) ** (k - 1) * (k - 1)
    for ell in range(2, k + 1):
        sign = (-1) ** (k - ell)
        binom = comb(k, ell, exact=True)
        total = total + sign * binom * np.exp(ell * (ell - 1) * r_over_sigma_tilde**2 / 2)
    return total


def B_bound(m):
    """Return the B term of the remainder bound for order m.

    For even m this is M_exact(m) itself. For odd m it is the geometric mean
    of the two neighbouring even orders, sqrt(M_{m-1} * M_{m+1})
    (presumably a Cauchy-Schwarz-type step — confirm against the paper).
    """
    if m % 2:
        return M_exact(m - 1) ** (1 / 2) * M_exact(m + 1) ** (1 / 2)
    return M_exact(m)

def A_bound(alpha, m):
    """Return the alpha-dependent A term of the remainder bound for order m.

    For even m:
        A = sum_{ell=0}^{m} C(m, ell) (-1)^(m-ell)
            * exp((alpha+ell-m-1)(alpha+ell-m) r^2 / 2),
    with r = r_over_sigma_tilde. For odd m it is the geometric mean of the
    values at the neighbouring even orders m-1 and m+1.
    """
    if m % 2:
        return A_bound(alpha, m - 1) ** (1 / 2) * A_bound(alpha, m + 1) ** (1 / 2)

    total = 0
    for ell in range(m + 1):
        binom = comb(m, ell, exact=True)
        total = total + binom * (-1) ** (m - ell) * np.exp(
            (alpha + ell - m - 1) * (alpha + ell - m) * r_over_sigma_tilde**2 / 2
        )
    return total

def R_bound(alpha, m):
    """Return the order-m term that H_bound appends after its truncated series.

    With alpha^(m) = alpha (alpha-1) ... (alpha-m+1) (falling factorial):

      * if 0 < alpha - m < 1:
            q^m / m! * alpha^(m) * (A + B)
      * otherwise:
            q^m / m! * alpha^(m) * (q/(m+1) * A + (1 - q/(m+1)) * B)

    where A = A_bound(alpha, m) and B = B_bound(m).
    """
    # Falling factorial alpha (alpha-1) ... (alpha-m+1).
    alpha_prod = alpha
    for j in range(1, m):
        alpha_prod = alpha_prod * (alpha - j)

    # math.factorial replaces np.math.factorial: the np.math alias was removed
    # in NumPy 2.0 (it was simply the stdlib math module).
    prefactor = q**m / math.factorial(m) * alpha_prod

    if 0 < alpha - m < 1:
        return prefactor * (A_bound(alpha, m) + B_bound(m))
    return prefactor * (q / (m + 1) * A_bound(alpha, m) + (1 - q / (m + 1)) * B_bound(m))
            

def H_bound(alpha, m):
    """Return the FS-RDP quantity H(alpha, m); RDP_eps_one_step takes its log.

    H = 1 + sum_{k=2}^{m-1} q^k / k! * alpha (alpha-1) ... (alpha-k+1) * M_exact(k)
          + R_bound(alpha, m)

    There is no k = 1 term; M_exact(1) evaluates to 0.
    """
    H_terms = [1.0]
    alpha_prod = alpha
    for k in range(2, m):
        alpha_prod = alpha_prod * (alpha - k + 1)  # falling factorial up to order k
        # math.factorial replaces np.math.factorial (np.math removed in NumPy 2.0).
        H_terms.append(q**k / math.factorial(k) * alpha_prod * M_exact(k))
    H_terms.append(R_bound(alpha, m))

    # Accumulate from the last term back to the leading 1; presumably so that the
    # small high-order terms are added first, for floating-point accuracy.
    H = 0.0
    for term in reversed(H_terms):
        H = H + term
    return H

# The bound performs poorly when alpha - m lies in (0, 1), so in that case we increase m by 1 to avoid that regime.
def RDP_eps_one_step(alpha, m):
    """Return the one-step RDP bound log(H(alpha, m')) / (alpha - 1).

    m' is m + 1 when 0 < alpha - m < 1 (a regime where the bound is loose),
    and m otherwise.
    """
    order = m + 1 if 0 < alpha - m < 1 else m
    return 1 / (alpha - 1) * np.log(H_bound(alpha, order))

## Implementation of the bounds from Wang et al.

def K_Wang(alpha):
    """Return log of Wang et al.'s series for an integer order alpha.

    K = log( 1 + 2 alpha (alpha-1) q^2 (exp(4/sigma^2) - 1)
               + sum_{j=3}^{alpha} 2 q^j alpha^(j) / j! * exp(2 j (j-1) / sigma^2) )

    with alpha^(j) the falling factorial alpha (alpha-1) ... (alpha-j+1).

    Args:
        alpha: Integer order. Integer-valued floats (e.g. 3.0) are accepted.

    Raises:
        ValueError: if alpha is not integer-valued.
    """
    if int(alpha) != alpha:
        raise ValueError(f"alpha must be an integer, got {alpha!r}.")
    # range() below needs a real int; previously a float such as 3.0 raised TypeError.
    alpha = int(alpha)

    K_terms = [1.0]
    alpha_prod = alpha * (alpha - 1)
    K_terms.append(2 * alpha_prod * q**2 * (np.exp(4 / sigma**2) - 1))

    for j in range(3, alpha + 1):
        alpha_prod = alpha_prod * (alpha - j + 1)
        # math.factorial replaces np.math.factorial (np.math removed in NumPy 2.0).
        K_terms.append(2 * q**j * alpha_prod / math.factorial(j) * np.exp((j - 1) * 2 * j / sigma**2))

    # Sum from the highest-order term back to the leading 1.
    K = 0
    for term in reversed(K_terms):
        K = K + term
    return np.log(K)

def Wang_et_al_upper_bound(alpha):
    """Return Wang et al.'s one-step RDP upper bound at order alpha.

    * alpha < 2: the value at alpha = 2 is returned. This presumably relies on
      RDP being nondecreasing in alpha — confirm.
    * integer alpha >= 2: K_Wang(alpha) / (alpha - 1).
    * non-integer alpha > 2: K_Wang is linearly interpolated between
      floor(alpha) and floor(alpha) + 1, then divided by (alpha - 1).
    """
    if not alpha >= 2:
        return Wang_et_al_upper_bound(2)

    lower = math.floor(alpha)  # an int, so K_Wang can use it in range()
    if lower == alpha:
        # Pass the int, not alpha itself: a float such as 3.0 previously reached
        # K_Wang's range() and raised TypeError.
        return 1.0 / (alpha - 1.0) * K_Wang(lower)

    frac = alpha - lower
    return (1.0 - frac) / (alpha - 1) * K_Wang(lower) + frac / (alpha - 1) * K_Wang(lower + 1)

def Wang_et_al_lower_bound(alpha):
    """Return Wang et al.'s one-step RDP lower bound at an integer order alpha.

    With p = q / (1 - q):
        L = 1 + alpha p + sum_{j=2}^{alpha} alpha^(j) / j! * p^j * exp(2 j (j-1) / sigma^2)
        result = alpha/(alpha-1) * log(1 - q) + log(L) / (alpha - 1)

    where alpha^(j) is the falling factorial. alpha = 1 divides by zero.

    Args:
        alpha: Integer order (integer-valued floats such as 3.0 are accepted).

    Raises:
        ValueError: if alpha is not integer-valued. This previously printed a
            message and returned None, which silently corrupted downstream arrays.
    """
    if int(alpha) != alpha:
        raise ValueError(f"alpha must be an integer, got {alpha!r}.")
    # range() needs a real int; previously a float such as 3.0 raised TypeError.
    alpha = int(alpha)

    L_terms = [1.0, alpha * q / (1 - q)]
    alpha_prod = alpha
    for j in range(2, alpha + 1):
        alpha_prod = alpha_prod * (alpha - j + 1)
        # math.factorial replaces np.math.factorial (np.math removed in NumPy 2.0).
        L_terms.append(alpha_prod / math.factorial(j) * (q / (1 - q)) ** j * np.exp((j - 1) * 2 * j / sigma**2))

    # Sum from the highest-order term back to the leading 1.
    L = 0
    for term in reversed(L_terms):
        L = L + term
    return alpha / (alpha - 1) * np.log(1 - q) + 1 / (alpha - 1) * np.log(L)


def get_eps(
    *,
    orders: Union[List[float], float],
    rdp: Union[List[float], float],
    delta: float,
) -> Tuple[float, float]:
    r"""Convert RDP guarantees at several orders into the tightest (epsilon, delta).

    For each order alpha with RDP value rdp(alpha):

        eps(alpha) = rdp(alpha) - (log(delta) + log(alpha)) / (alpha - 1)
                     + log((alpha - 1) / alpha)

    The minimum over the orders is returned, ignoring NaNs. A warning is issued
    when the minimum sits at either end of the supplied orders.

    Based on:
    Borja Balle et al. "Hypothesis testing interpretations and Renyi differential privacy."
    International Conference on Artificial Intelligence and Statistics. PMLR.
    Particularly, Theorem 21 in the arXiv version https://arxiv.org/abs/1905.09982.

    Returns:
        (epsilon, optimal order), or (inf, nan) when every candidate is NaN.

    Raises:
        ValueError: if ``orders`` and ``rdp`` have different lengths.
    """
    alphas = np.atleast_1d(orders)
    rdps = np.atleast_1d(rdp)

    if len(alphas) != len(rdps):
        raise ValueError(
            f"Input lists must have the same length.\n"
            f"\torders_vec = {alphas}\n"
            f"\trdp_vec = {rdps}\n"
        )

    penalty = (np.log(delta) + np.log(alphas)) / (alphas - 1)
    log_ratio = np.log((alphas - 1) / alphas)
    eps = rdps - penalty + log_ratio

    # Every candidate is NaN: no usable privacy guarantee.
    if np.all(np.isnan(eps)):
        return np.inf, np.nan

    best = np.nanargmin(eps)
    if best in (0, len(eps) - 1):
        extreme = "smallest" if best == 0 else "largest"
        warnings.warn(
            f"Optimal order is the {extreme} alpha. Please consider expanding the range of alphas to get a tighter privacy bound. delta is: {delta}"
        )
    return eps[best], alphas[best]


# ---------------------------------------------------------------------------
# Experiment: compare one-step RDP bounds, then convert the RDP composed over
# the whole training run into (epsilon, delta) guarantees.
# ---------------------------------------------------------------------------
N_alpha = 500

# Renyi orders, log-spaced between 1 + 10**-1 = 1.1 and 1 + 10**1.5 (about 32.6).
alpha_array = 1 + 10 ** np.linspace(-1, 1.5, N_alpha)

# Truncation orders m of the FS-RDP expansion to compare.
m_array = [3, 4, 6, 8, 10]
N_m = len(m_array)

deltas = [1e-10, 3e-10, 5e-10, 1e-9, 3e-9, 5e-9, 1e-8, 3e-8, 5e-8, 1e-7, 3e-7, 5e-7,
          1e-6, 3e-6, 5e-6, 1e-5, 3e-5, 5e-5, 1e-4]
n_epochs = 250
n_train = 50000
batch_size = 120
# Total number of training steps. The one-step RDP values are multiplied by this
# below, i.e. RDP is composed additively over steps.
# NOTE(review): the module-level q is hard-coded as 120/50000; keep it in sync
# with batch_size / n_train here.
accu_factor = n_train / batch_size * n_epochs

# One-step FS-RDP bound for every (m, alpha) pair.
eps_array = np.zeros((N_m, N_alpha))
for j1, m in enumerate(m_array):
    for j2, alpha in enumerate(alpha_array):
        eps_array[j1, j2] = RDP_eps_one_step(alpha, m)

# One-step upper bound of Wang et al. on the same alpha grid.
Wang_eps_array = np.array([Wang_et_al_upper_bound(alpha) for alpha in alpha_array], dtype=float)

# Wang et al.'s lower bound accepts integer orders only; start at 2 because
# alpha / (alpha - 1) is undefined at alpha = 1.
alpha_lb_array = list(range(2, int(max(alpha_array)) + 1))
Wang_eps_lb_array = [Wang_et_al_lower_bound(alpha) for alpha in alpha_lb_array]

# --- Figure 1: one-step RDP bounds versus alpha ---
plt.figure()
labels = []
for k, m in enumerate(m_array):
    plt.semilogy(alpha_array, eps_array[k, :])
    labels.append("FS-RDP (m = {})".format(m))

plt.semilogy(alpha_array, Wang_eps_array, '--')
labels.append("Wang et al. upper bound")
plt.semilogy(alpha_lb_array, Wang_eps_lb_array, '.')
labels.append("Wang et al. lower bound")

plt.grid(color='gray', linestyle='--', linewidth=0.5)
plt.legend(labels)
plt.xlabel(r'$\alpha$')
plt.ylabel('One-step RDP Bound')
plt.savefig("comp_with_Wang_et_al.pdf")

# --- Figure 2: (epsilon, delta) guarantees for the full training run ---
plt.figure()
labels = []
for k, m in enumerate(m_array):
    rdp = eps_array[k, :] * accu_factor
    eps_ours = [get_eps(orders=alpha_array, rdp=rdp, delta=d)[0] for d in deltas]
    plt.semilogx(deltas, eps_ours, marker='1')
    labels.append('FS$_{woR}$-RDP' + '(m = {})'.format(m))

# The composed RDP does not depend on delta, so compute it once per curve.
rdp_wang_up = Wang_eps_array * accu_factor
eps_wang_up = [get_eps(orders=alpha_array, rdp=rdp_wang_up, delta=d)[0] for d in deltas]
plt.semilogx(deltas, eps_wang_up, linestyle='--')
labels.append('Upper bound (Wang et al.)')

rdp_wang_low = np.array(Wang_eps_lb_array) * accu_factor
eps_wang_low = [get_eps(orders=alpha_lb_array, rdp=rdp_wang_low, delta=d)[0] for d in deltas]
plt.semilogx(deltas, eps_wang_low, '.')
labels.append('Lower bound (Wang et al.)')

plt.xlabel(r'$\delta$')
plt.ylabel(r'$\epsilon$', rotation=0)
plt.legend(labels)
plt.grid(color='gray', linestyle='--', linewidth=0.5)
# Save BEFORE show(): with blocking/non-interactive backends the figure is
# closed once show() returns, and a later savefig() writes a blank page.
plt.savefig("guarantee.pdf")
plt.show()
    
    
    
