import time
import scipy
import scipy.signal
from scipy.stats import norm
from scipy.stats import laplace
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy import integrate
import math
from tqdm import tqdm

### define privacy_params class
class privacy_params():
    """Hyper-parameters of a DP-SGD run, consumed by the accountants in this file.

    Args:
        sigma: noise multiplier.
        num_samples: dataset size N.
        batch_size: batch size B; ``sampling_prob()`` is B/N.
        epochs: number of training epochs.
        delta: default target delta (used when no delta is passed explicitly).
        num_JL: number of JL projections; ``None`` selects the plain DP-SGD
            single-iteration computation in ``privacy_engine_dpdl``.
        mesh_size: grid spacing of the discretised distribution.  When ``None``
            it defaults to ``(B/N)**2 * (exp(1/sigma**2) - 1) / 2 / 10`` and the
            chosen value is printed.
        max_eps: upper y-limit of the plots and lower bound on the grid range.
        min_delta: smallest delta of interest (stored for callers).
        precision: probability-mass threshold used by the accountant.
    """
    # Class-level defaults; __init__ sets every one of them per instance.
    num_JL=None
    sigma=None
    num_samples=None
    batch_size=None
    epochs=None
    mesh_size=None
    delta=None
    # t_limit_stdPQ=None
    # t_limit_stdPQ_single_iter=None
    max_eps=None
    min_delta=None
    precision=None

    def __init__(self,sigma,num_samples,batch_size,epochs,delta,num_JL=None,mesh_size=None,max_eps=100,min_delta=10**(-10),precision=10**(-16)):
        self.num_JL=num_JL
        self.epochs=epochs
        self.sigma=sigma
        self.num_samples=num_samples
        self.batch_size=batch_size
        if mesh_size is None:
            # Heuristic default: a tenth of (B/N)^2 * (exp(1/sigma^2) - 1) / 2.
            self.mesh_size = (batch_size/num_samples)**2 * (math.exp(1/sigma**2)-1)/2 /10
            print('mesh_size: ',self.mesh_size)
        else:
            self.mesh_size=mesh_size
        self.delta=delta
        self.max_eps=max_eps
        self.min_delta=min_delta
        self.precision=precision

    def num_iter(self):
        """Total number of iterations: ``epochs * ceil(num_samples / batch_size)``."""
        return int(self.epochs*math.ceil(self.num_samples/self.batch_size))

    def sampling_prob(self):
        """Batch sampling probability ``batch_size / num_samples``."""
        return self.batch_size/self.num_samples


class privacy_engine_dpdl():
    """Numerical f-DP accountant for DP-SGD (optionally with JL projections).

    The engine keeps a discretised probability mass function ``stdQ_pdf`` on the
    symmetric grid ``t_list``.  From it ``calculate_delta`` evaluates

        delta(eps) = sum_{t_i >= eps} stdQ_pdf[i] * (1 - exp(eps - t_i)),

    and ``calculate_eps`` inverts that relation numerically.  Composition over
    training steps is done by convolving the single-iteration pmf with itself
    (``add_epochs`` / ``add_iterations``).
    """
    # Class-level defaults; the constructor and the add_* methods set per-instance values.
    privacy_args=None            # privacy_params instance being accounted for
    precision=None               # masses <= precision are zeroed by add_epochs
    stdQ_pdf_single_iter=None    # pmf of one iteration on t_list
    stdQ_pdf_single_epoch=None   # pmf of one epoch (ceil(N/B) iterations) on t_list
    stdQ_pdf=None                # pmf after `current_epoch` epochs; None until the first add_*
    stdQ_pdf_normalized=None     # not used in this class
    t_list=None                  # symmetric evaluation grid of the pmfs
    t_limit_stdPQ=None           # grid half-width returned by the single-iteration code
    exp_minust_list=None         # exp(-t) on t_list (0 where |t| >= 700)
    verbose=False
    current_epoch=None           # epochs composed so far (fractional after add_iterations)
    min_delta_limit=10**(-12)    # not referenced in this class

    def __init__(self,privacy_args,verbose=False,interpolate_flag=True):
        """Precompute the single-iteration and single-epoch distributions.

        Args:
            privacy_args: a ``privacy_params`` instance.
            verbose: print progress and timing information.
            interpolate_flag: forwarded to ``calculate_stdQ_single_iter_DPSGD``
                (the JL variant does not take it).
        """
        self.privacy_args=privacy_args
        self.precision=privacy_args.precision
        self.verbose=verbose
        if verbose:
            print('Starting some privacy precomputations...')

        start_time=time.time()

        if self.privacy_args.num_JL is None:
            self.stdQ_pdf_single_iter,self.t_list,self.t_limit_stdPQ=calculate_stdQ_single_iter_DPSGD(privacy_args,verbose=verbose,interpolate_flag=interpolate_flag)
        else:
            self.stdQ_pdf_single_iter,self.t_list,self.t_limit_stdPQ=calculate_stdQ_single_iter_DPSGDJL(privacy_args,verbose=verbose)

        # exp(-t) on the grid; |t| >= 700 is zeroed because math.exp overflows just above 709.
        self.exp_minust_list = np.array([math.exp(-t) if abs(t)<700 else 0 for t in self.t_list])

        assert(len(self.stdQ_pdf_single_iter)==len(self.t_list))

        self.stdQ_pdf_single_epoch = fast_convolve(self.stdQ_pdf_single_iter,self._iters_per_epoch())

        self.current_epoch=0

        if verbose:
            print('Time: ',time.time()-start_time)
            print('Finished privacy precomputations. Increase mesh_size if this is too slow.')

    def _iters_per_epoch(self):
        """Number of iterations in one epoch: ``ceil(num_samples / batch_size)``."""
        return int(math.ceil(self.privacy_args.num_samples/self.privacy_args.batch_size))

    def _mu_ideal(self):
        """GDP parameter used for the CLT comparison curve at ``current_epoch`` epochs."""
        nu=math.sqrt(self.current_epoch*self.privacy_args.batch_size/self.privacy_args.num_samples)
        return nu*np.sqrt(np.exp(1/self.privacy_args.sigma**2)-1)

    def add_epochs(self,N=1):
        """Compose ``N`` more epochs into ``stdQ_pdf``.

        Afterwards, masses at or below ``self.precision`` are zeroed and
        ``stdQ_pdf`` is stored as a Python list.
        """
        start=time.time()
        stdQ_pdf_temp=fast_convolve(self.stdQ_pdf_single_epoch,N)

        if self.current_epoch==0:
            self.stdQ_pdf=stdQ_pdf_temp
        else:
            # 'same' keeps the result on the t_list grid; mass pushed past the edges is dropped.
            self.stdQ_pdf = np.real(scipy.signal.convolve(self.stdQ_pdf,stdQ_pdf_temp,'same'))

        self.current_epoch+=N

        self.stdQ_pdf = [p if p>self.precision else 0 for p in self.stdQ_pdf]

        if self.verbose:
            print('Finished adding epochs. Time taken:',time.time()-start)

    def add_iterations(self,N):
        """Compose ``N`` more single iterations into ``stdQ_pdf``.

        ``current_epoch`` is advanced by the corresponding (possibly fractional)
        number of epochs.
        """
        stdQ_pdf_temp=fast_convolve(self.stdQ_pdf_single_iter,N)

        if self.current_epoch==0:
            self.stdQ_pdf=stdQ_pdf_temp
        else:
            self.stdQ_pdf = np.real(scipy.signal.convolve(self.stdQ_pdf,stdQ_pdf_temp,'same'))

        # The dataset/batch sizes live on privacy_args, not on the engine itself.
        self.current_epoch+=N/self._iters_per_epoch()

    def print_stdQ_stats(self):
        """Print the mean/stddev predicted from one iteration vs. those of ``stdQ_pdf``."""
        mean_Q_single,stddev_Q_single = find_mean_stddev(self.t_list,self.stdQ_pdf_single_iter)
        T=self.privacy_args.num_iter()
        print(f'Expected: mean stdQ_T={T*mean_Q_single}')
        print(f'Expected: stddev stdQ_T={math.sqrt(T)*stddev_Q_single}')

        if self.stdQ_pdf is None:
            # Nothing has been composed yet (no add_epochs / add_iterations call).
            return None

        mean_Q,stddev_Q = find_mean_stddev(self.t_list,self.stdQ_pdf)
        print(f'Convolution:  mean stdQ_T={mean_Q}')
        print(f'Convolution:  stddev stdQ_T={stddev_Q}')

    def plot_eps_delta_curve_old(self,max_eps=None,min_delta=None,CLT_approx_DPSGD=False):
        """Plot eps vs delta by calling ``calculate_delta`` on 20 eps values."""
        if self.current_epoch==0:
            return None
        if max_eps is None:
            max_eps=self.privacy_args.max_eps

        if min_delta is None:
            min_delta=10**(-10)

        eps_list = np.linspace(0.01,max_eps,20)

        start=time.time()
        print('Calculting eps vs delta curve...')
        delta_list=[self.calculate_delta(eps) for eps in eps_list]
        print('Finished. Time taken:',time.time()-start)

        plt.xlim(min_delta,1)

        plt.ylim(0,max_eps)
        plt.scatter(delta_list,eps_list,color='r',s=3,label='Exact fDP accountant')

        if CLT_approx_DPSGD:
            mu_ideal = self._mu_ideal()

            print('mu_ideal=',mu_ideal)

            delta_list_CLT=np.geomspace(min_delta,delta_Gaussian(0,mu_ideal),100)
            eps_list_CLT=[eps_Gaussian(delta,mu_ideal) for delta in delta_list_CLT]

            plt.plot(delta_list_CLT,eps_list_CLT,color='b',label='Approximate GDP accountant')

        plt.xlabel('delta')
        plt.ylabel('eps')
        plt.xscale('log')
        plt.legend()
        plt.show()

    def plot_eps_delta_curve(self,max_eps=None,min_delta=None,CLT_approx_DPSGD=False, MA=False):
        """Plot the exact eps-vs-delta curve at every grid point of ``t_list``.

        Optionally overlays the GDP (CLT) approximation and the moments
        accountant curve.
        """
        if self.current_epoch==0:
            return None
        if max_eps is None:
            max_eps=self.privacy_args.max_eps

        if min_delta is None:
            min_delta=10**(-10)

        start=time.time()
        print('Calculting eps vs delta curve...')

        N = len(self.t_list)
        A = [0]*N #[Pr[Y > t] for t in t_list]
        B = [0]*N #[E[e^(-Y)1(Y>t)] for t in t_list]
        # Suffix sums let delta be evaluated at every grid point in one O(N) pass.
        for i in reversed(range(1,N)):
            A[i-1]=A[i]+self.stdQ_pdf[i]
            B[i-1]=B[i]+math.exp(-self.t_list[i])*self.stdQ_pdf[i]

        eps_list=self.t_list
        delta_list = [A[i]-math.exp(self.t_list[i])*B[i] for i in range(N)]

        print('Finished. Time taken:',time.time()-start)

        plt.xlim(min_delta,1)

        plt.ylim(0,max_eps)
        plt.scatter(delta_list,eps_list,color='r',s=3,label='Exact fDP accountant')

        # Computed whenever either overlay needs it; previously MA alone raised NameError.
        if CLT_approx_DPSGD or MA:
            mu_ideal = self._mu_ideal()

        if CLT_approx_DPSGD:
            print('Calculting epsilons using GDP...')
            print('mu_ideal',mu_ideal)
            delta_list_CLT=np.geomspace(min_delta,delta_Gaussian(0,mu_ideal),100)
            eps_list_CLT=[eps_Gaussian(delta,mu_ideal) for delta in delta_list_CLT]

            plt.plot(delta_list_CLT,eps_list_CLT,color='b',label='Approximate GDP accountant')

        if MA:
            print('Calculting epsilons using MA...')
            delta_list_MA=np.geomspace(min_delta,delta_Gaussian(0,mu_ideal),20)
            eps_list_MA=[compute_epsilon_MA(self.privacy_args,delta) for delta in delta_list_MA]

            plt.plot(delta_list_MA,eps_list_MA,color='g',label='MA')

        plt.xlabel('delta')
        plt.ylabel('eps')
        plt.xscale('log')
        plt.legend()
        plt.show()

    def calculate_eps(self,delta=None,tol=0.1):
        """Return the eps at which ``calculate_delta(eps) == delta``.

        Args:
            delta: target delta; defaults to ``privacy_args.delta``.
            tol: kept for backward compatibility; not used.

        Returns:
            eps as a float, ``0`` when delta is already reached at eps = 0, or
            ``None`` when nothing has been composed yet or the target lies
            beyond the grid (an error message is printed in that case).
        """
        if self.current_epoch==0:
            return None

        if delta is None:
            delta=self.privacy_args.delta

        # The root must lie inside the grid; check delta near its right edge first.
        delta_at_edge=self.calculate_delta(self.t_limit_stdPQ-10*self.privacy_args.mesh_size)
        if delta_at_edge is None or delta_at_edge>delta:
            print('ERROR: eps out of range (Increase max_eps or t_limit_stdPQ)')
            return None

        # delta(eps) is non-increasing; if it is already small enough at 0 there is no bracket.
        delta_at_zero=self.calculate_delta(0)
        if delta_at_zero is not None and delta_at_zero<=delta:
            return 0

        eps = scipy.optimize.root_scalar(lambda x: self.calculate_delta(x)-delta,bracket=[0,self.t_limit_stdPQ-1],method='brentq').root
        return eps

    def calculate_delta(self,eps):
        """Return ``sum_{t_i >= eps} stdQ_pdf[i] * (1 - exp(eps - t_i))``.

        Returns ``None`` (after printing an error) when eps maps to the last
        grid point or beyond.
        """
        i = np.argmin(abs(np.array(self.t_list)-eps))
        if i>=len(self.t_list)-1:
            print('ERROR: eps out of range')
            return None

        # Move to the first grid point that is >= eps.
        if self.t_list[i]<eps:
            i+=1
        assert(self.t_list[i]>=eps)

        z=1-math.exp(eps)*self.exp_minust_list[i:]
        return np.dot(self.stdQ_pdf[i:],z)


       
##########################################################################################################################
##########################################################################################################################
##########################################################################################################################
##########################################################################################################################
##########################################################################################################################
##########################################################################################################################
##########################################################################################################################
############### Some general functions #########################

def calculate_eps_vs_epoch(privacy_engine,epochs):
    """Advance ``privacy_engine`` one epoch at a time and record eps after each.

    Args:
        privacy_engine: object exposing ``add_epochs(n)`` and ``calculate_eps()``.
        epochs: number of epochs to add.

    Returns:
        List of length ``epochs``; entry k is the eps after k+1 added epochs.
    """
    def _advance_and_measure():
        privacy_engine.add_epochs(1)
        return privacy_engine.calculate_eps()

    return [_advance_and_measure() for _ in range(epochs)]


def G(alpha,mu):
    """Gaussian trade-off curve: ``Phi(Phi^{-1}(1 - alpha) - mu)``."""
    threshold = norm.ppf(1 - alpha)
    return norm.cdf(threshold - mu)

def G_derivate(alpha,mu):
    """Derivative of ``G`` w.r.t. alpha: ``-exp(z**2/2 - (z - mu)**2/2)`` with ``z = Phi^{-1}(1 - alpha)``."""
    z = norm.ppf(1 - alpha)
    exponent = z**2/2 - (z - mu)**2/2
    return -np.exp(exponent)

def delta_Gaussian(eps,mu):
    """delta(eps) of mu-GDP: ``Phi(mu/2 - eps/mu) - e^eps * Phi(-mu/2 - eps/mu)``."""
    upper_arg = mu/2 - eps/mu
    lower_arg = -mu/2 - eps/mu
    return norm.cdf(upper_arg) - math.exp(eps)*norm.cdf(lower_arg)


def eps_Gaussian(delta,mu):
    """Invert ``delta_Gaussian``: the eps >= 0 at which mu-GDP reaches ``delta``.

    Returns 0 when ``delta`` is already at least ``delta_Gaussian(0, mu)``;
    otherwise it solves with Brent's method on [0, 700].
    """
    delta_at_zero = delta_Gaussian(0,mu)
    if delta >= delta_at_zero:
        return 0

    residual = lambda eps: delta_Gaussian(eps,mu) - delta
    solution = scipy.optimize.root_scalar(residual, bracket=[0, 700], method='brentq')
    return solution.root




def avg_neighbors(L):
    """Return the list of midpoints ``(L[k] + L[k+1]) / 2`` of consecutive entries."""
    return [(left + right)/2 for left, right in zip(L, L[1:])]

def find_pdf_from_sf(sf):
    """Turn survival-function samples into per-bin masses ``sf[k] - sf[k+1]``.

    The result has one entry fewer than ``sf``.
    """
    return [upper - lower for upper, lower in zip(sf, sf[1:])]

def fast_convolve(f,n):
    """Return the n-fold self-convolution of ``f`` using binary exponentiation.

    Every convolution uses ``mode='same'``, so the result has ``len(f)``
    entries aligned with ``f``; mass that falls off either end is dropped.

    Args:
        f: 1-D array-like (a pmf on a fixed grid).
        n: positive number of factors.

    Returns:
        ``f`` itself when ``n == 1``, the real-valued n-fold convolution
        otherwise, or ``None`` (after printing an error) when ``n <= 0``.
    """
    if n<=0:
        print('ERROR: fast_convolve, n should be positive integer')
        return None
    if n==1:
        return f
    ans=f
    power=f          # f convolved with itself 2**k times
    remaining=n-1    # ans already holds one factor of f
    while remaining>0:
        if remaining%2 == 1:
            ans=np.real(scipy.signal.convolve(ans,power,'same'))
        remaining//=2
        # Only square while exponent bits remain; the last squaring would be wasted work.
        if remaining>0:
            power=np.real(scipy.signal.convolve(power,power,'same'))
    return ans


def find_mean_stddev(vals,pdf):
    """Return ``(mean, stddev)`` of the discrete distribution with masses ``pdf`` at points ``vals``.

    The masses are used as given (they are not renormalised).
    """
    pairs = list(zip(vals, pdf))
    mean = np.sum([point * weight for point, weight in pairs])
    variance = np.sum([(point - mean)**2 * weight for point, weight in pairs])
    return mean, math.sqrt(variance)


from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

# Compute epsilon by MA
def compute_epsilon_MA(privacy_args,delta=None):
    """Compute epsilon with the moments accountant (tensorflow_privacy RDP).

    Args:
        privacy_args: a ``privacy_params`` instance; reads ``sigma``,
            ``num_samples``, ``batch_size``, ``num_iter()`` and, when
            ``delta`` is None, ``delta``.
        delta: target delta; defaults to ``privacy_args.delta``.

    Returns:
        The epsilon (first element) returned by ``get_privacy_spent``.
    """
    if delta is None:
        delta=privacy_args.delta

    # RDP orders at which the bound is evaluated.
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) + list(range(65, 200, 5)) + list(range(200, 1000, 10))
    sampling_probability = privacy_args.batch_size / privacy_args.num_samples
    # Use the same step count as the exact accountant (epochs * ceil(N/B)) so the
    # two curves compare the same training run; epochs*N/B was fractional when B
    # does not divide N.
    rdp = compute_rdp(q=sampling_probability,
                      noise_multiplier=privacy_args.sigma,
                      steps=privacy_args.num_iter(),
                      orders=orders)
    return get_privacy_spent(orders, rdp, target_delta=delta)[0]


#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
#####################################################################################################################



def calculate_stdQ_single_iter_DPSGD(privacy_args,t_limit_stdPQ=None,verbose=False,interpolate_flag=True):
    """Discretise the single-iteration distribution ``stdQ`` for DP-SGD.

    The tail function ``oneminus_beta_p(t)`` (= Pr[stdQ > t]) combines the
    Gaussian curves with mu = 1/sigma using the weight
    ``p = batch_size / num_samples`` (presumably modelling batch sampling —
    confirm against the paper this implements).  The pmf is then read off on a
    grid of spacing ``mesh_size``.

    Args:
        privacy_args: ``privacy_params``; reads batch_size, num_samples,
            mesh_size, sigma, max_eps, min_delta (unused) and precision.
        t_limit_stdPQ: half-width of the grid.  Defaults to ``max_t`` (where
            the tail drops below ``1e-5 * precision``) and must be >= it.
        verbose: print diagnostics.
        interpolate_flag: if False, evaluate the tail at every bin edge.  If
            True, use coarser steps further into the tail and fill each step
            with a geometric sequence (``pdf_interpolate``).

    Returns:
        ``(stdQ_pdf_list, t_list, t_limit_stdPQ)``.  ``t_list`` is the
        symmetric grid ``{-k*mesh_size, ..., 0, ..., k*mesh_size}``, and
        ``stdQ_pdf_list[k]`` is the mass in the bin of width ``mesh_size``
        centred at ``t_list[k]``.
    """
    
    ########################### Internal parameters ##########################

    print('Calculating single iteration privacy curves for DPSGD...')
    B = privacy_args.batch_size #Batch size
    N = privacy_args.num_samples #number of samples
    mesh_size=privacy_args.mesh_size
    sigma=privacy_args.sigma
    
    p=B/N #Batch sampling probability


    
   

    ################################ Calculting stdP, stdQ for a single iteration ################################

    # Gaussian curves with mu = 1/sigma (t/(1/sigma) == t*sigma):
    #   alpha(t)         = Phi(-t*sigma - 1/(2*sigma))
    #   oneminus_beta(t) = 1 - Phi(t*sigma - 1/(2*sigma))
    def alpha(t):
        return scipy.stats.norm.cdf(-t/(1/sigma) - (1/sigma)/2)

    def oneminus_beta(t):
        return scipy.stats.norm.sf(t/(1/sigma) - (1/sigma)/2)

    
    max_eps=privacy_args.max_eps
    min_delta=privacy_args.min_delta  # NOTE(review): read but never used below
   
    # Tail Pr[stdQ > t] for t > 0; same expression as the t>0 branch of oneminus_beta_p below.
    f_oneminus_beta_p = lambda t: p*oneminus_beta(t+math.log(1/p-(1-p)*math.exp(-t)/p))+(1-p)*alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p))

    # Tail mass below which the grid is truncated.
    precision = 10**(-5)*privacy_args.precision
    
    # max_t: where the tail drops to `precision` (capped at 700), but at least max_eps.
    if f_oneminus_beta_p(700)>=precision:
        max_t=700
    else:
        max_t=max(max_eps,scipy.optimize.root_scalar(lambda x: f_oneminus_beta_p(x)-precision,bracket=[math.log(1-p)+10**(-15),700],method='brentq').root)
    #max_t=700

    if t_limit_stdPQ==None:
        t_limit_stdPQ=max_t 

    assert(t_limit_stdPQ>=max_t)

    if verbose:
        print('t_limit_stdPQ: ',t_limit_stdPQ)
        print('max_t: ',max_t)

    
    # Bin edges at odd multiples of mesh_size/2, symmetric around 0.
    t_list_cdf_half = np.arange(mesh_size/2,t_limit_stdPQ,mesh_size)
    t_list_cdf = np.append(-np.flipud(t_list_cdf_half),t_list_cdf_half)

    # Bin centres: midpoints of consecutive edges, plus 0 (one entry fewer than t_list_cdf).
    t_list_half = avg_neighbors(t_list_cdf_half)
    t_list = np.append(np.append(-np.flipud(t_list_half),[0]),t_list_half)
    



    # Piecewise tail Pr[stdQ > t]: 0 at/after max_t, 1 at/below log(1-p).  Both middle
    # branches evaluate the curves at log((e^t - (1-p))/p); the t>0 form is the same
    # quantity rearranged (presumably for numerical stability at large t).
    def oneminus_beta_p(t):
        if t >= max_t:
            return 0
        elif t>0:
            return p*oneminus_beta(t+math.log(1/p-(1-p)*math.exp(-t)/p))+(1-p)*alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p))
        elif t>math.log(1-p):
            return  p*oneminus_beta(math.log((math.exp(t)-(1-p))/p))+(1-p)*alpha(math.log((math.exp(t)-(1-p))/p))
        else:
            return 1

    oneminus_beta_p_list = []

    if interpolate_flag==False:

        # Exact path: evaluate the tail at every bin edge and difference it.
        for t in tqdm(t_list_cdf):
            oneminus_beta_p_list.append(oneminus_beta_p(t))

        stdQ_pdf_list = find_pdf_from_sf(oneminus_beta_p_list)

    else:


        # Adaptive path: once the tail falls below prob_scale[n] (at t_scale[n]),
        # switch to steps of step_scale[n] grid points (mesh growing geometrically up to 5).
        num_scales=8
        prob_scale=np.geomspace(0.01,oneminus_beta_p(max_t-1)+10**(-40),num_scales)
        #prob_scale = [10**(-2-n) for n in range(1,num_scales+1)]
        mesh_scale=np.geomspace(mesh_size,5,num_scales)
        #mesh_scale=[mesh_size*4**n for n in range(num_scales)]
        step_scale=[int(math.ceil(mesh_scale[n]/mesh_size)) for n in range(num_scales)]
        t_scale=[scipy.optimize.root_scalar(lambda x: oneminus_beta_p(x)-prob_scale[n],bracket=[math.log(1-p)+10**(-10),max_t],method='brentq').root for n in range(num_scales)]
            
        if verbose:
            print('num_scales:',num_scales)
            print('prob_scale:',prob_scale)
            print('mesh_scale:',mesh_scale)
            print('step_scale:',step_scale)
            print('t_scale:',t_scale)


        stdQ_pdf_list=[0]


        # Return n values a0*r, a0*r^2, ..., a0*r^n whose sum is b (geometric
        # continuation of the previous mass a0); all zeros when b <= 0.
        def pdf_interpolate(a0,b,n):
            if n==1:
                return [b]
            if b<=0:
                r=0
            else:
                if a0<=0:
                    a0=10**(-30)
                f=lambda r: (a0/b)*(r**n-1)*r/(r-1)-1
                #print(a0,b,(b/a0)**(1/n))
                r = scipy.optimize.root_scalar(f,bracket=[0,(b/a0)**(1/n)]).root
            return [a0*r**i for i in range(1,n+1)]

        i=1 #counter
        n=0 #scale

        b1=1  # tail value at the left edge of the current step
        # The `for j` loop only drives the progress bar; a step is processed whenever
        # j catches up with i, i.e. effectively `while i < len(t_list)`.
        for j in tqdm(range(1,len(t_list))):
            if i>j:
                continue
            else:
                if t_list_cdf[i]>t_scale[n] and n<num_scales-1:
                    n+=1 #increase scale


                step = min(step_scale[n],len(t_list)-i)
                assert(step>0)
                b2 = oneminus_beta_p(t_list_cdf[i+step])

                assert(len(stdQ_pdf_list)==i)

                # Spread the mass b1-b2 of this step over `step` bins.
                stdQ_pdf_list+=pdf_interpolate(stdQ_pdf_list[i-1],b1-b2,step)

                b1=b2
                i+=step



    assert(len(t_list)==len(stdQ_pdf_list))



    if verbose:
        print('(should be close to 1) sum stdQ_pdf: ',sum(stdQ_pdf_list))


    mean_stdQ,stddev_stdQ = find_mean_stddev(t_list,stdQ_pdf_list)

    if verbose:
        print(f'mean_stdQ={mean_stdQ}')
        print(f'stddev_stdQ={stddev_stdQ}')


    if abs(mean_stdQ) < mesh_size:
        print('WARNING: mesh_size is larger than mean_stdQ, might affect accuracy of privacy calculations. Try reducing mesh_size.')

    print('Finished calculating privacy curves for single iteration\n')




    return stdQ_pdf_list, t_list, t_limit_stdPQ


#############################################################################################################################
#############################################################################################################################
#############################################################################################################################
#############################################################################################################################
#############################################################################################################################
#############################################################################################################################


def calculate_stdQ_single_iter_DPSGDJL(privacy_args,t_limit_stdPQ=None,verbose=False):
    """Discretise the single-iteration distribution ``stdQ`` for DP-SGD with JL projections.

    Same construction as ``calculate_stdQ_single_iter_DPSGD``, except that the
    Gaussian curves are averaged (via ``integrate.quad``) over ``a ~ Xr``, where
    ``Xr`` is distributed as ``1/sqrt(chi_r^2 / r)`` and ``r = num_JL``.  Only
    the adaptive (interpolating) grid fill is implemented.

    Args:
        privacy_args: ``privacy_params``; reads num_JL, batch_size,
            num_samples, mesh_size, sigma, max_eps, min_delta and precision.
        t_limit_stdPQ: half-width of the grid.  Defaults to ``max_t`` and must
            be >= it.
        verbose: print diagnostics.

    Returns:
        ``(stdQ_pdf_list, t_list, t_limit_stdPQ)`` with the same layout as the
        DP-SGD version.  When ``num_JL`` is None it prints an error and returns
        a bare ``None``.  NOTE(review): a caller that unpacks three values will
        then raise TypeError.
    """
    
    ########################### Internal parameters ##########################
    

    r=privacy_args.num_JL #Number of JL projections
    if r==None:
        print('ERROR: Empty num_JL')
        return None

    print(f'Calculting single iteration privacy curves for DPSGD_JL with JLdim={r} ...')

    B = privacy_args.batch_size #Batch size
    N = privacy_args.num_samples #number of samples
    mesh_size=privacy_args.mesh_size
    sigma=privacy_args.sigma
    T = privacy_args.num_iter()  # NOTE(review): unused below
    
    p=B/N #Batch sampling probability

    ##########################################################


    #Cr = integrate.quad(lambda t: (1/t**(r+1)) * scipy.exp(-r/(2 * t**2)),0,np.inf)[0]
    #Xr_pdf= lambda t: (1/Cr)*(1/t**(r+1)) * scipy.exp(-r/(2 * t**2)) if t>0 else 0 #pdf of 1/\sqrt{(1/r)\chi_r^2}

    # Change of variables: if Y ~ chi_r scaled by 1/sqrt(r), then X = 1/Y has pdf f_Y(1/t)/t^2.
    Xr_pdf = lambda t: scipy.stats.chi.pdf(1/t,df=r,loc=0,scale=1/math.sqrt(r))/t**2 if t>0 else 0 #pdf of 1/\sqrt{(1/r)\chi_r^2}

    sqrt2pi=np.sqrt(2*np.pi)  # NOTE(review): unused below
    invsqrt2pi=1/sqrt2pi      # NOTE(review): unused below
    ###################### Calculating standard parametrization for a single step #######################

    # Gaussian curves with mu = a/sigma, averaged over a ~ Xr.
    def alpha(t):
        return integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.cdf(-t/(a/sigma) - (a/sigma)/2),0,np.inf)[0]

    def beta(t):
        return integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.cdf(t/(a/sigma) - (a/sigma)/2),0,np.inf)[0]

    def oneminus_alpha(t):
        return integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.sf(-t/(a/sigma) - (a/sigma)/2),0,np.inf)[0]

    # def oneminus_beta(t):
    #   return integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.sf(t/(a/sigma) - (a/sigma)/2),0,np.inf)[0]

    # For t > 0 the integral is split at a = sigma*sqrt(2t), where the norm.sf argument
    # t*sigma/a - a/(2*sigma) crosses 0, to help quad resolve the integrand.
    def oneminus_beta(t):
        if t>0:
            s1=integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.sf(t/(a/sigma) - (a/sigma)/2),0,sigma*math.sqrt(2*t))[0]
            s2=integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.sf(t/(a/sigma) - (a/sigma)/2),sigma*math.sqrt(2*t),np.inf)[0]
            return s1+s2
        else:
            return integrate.quad(lambda a: Xr_pdf(a)*scipy.stats.norm.sf(t/(a/sigma) - (a/sigma)/2),0,np.inf)[0]



    max_eps=privacy_args.max_eps      # NOTE(review): unlike the DP-SGD version, not used to bound max_t
    min_delta = privacy_args.min_delta  # NOTE(review): unused below
   



    print('Calculting privacy curves for single iteration...')

    
    ################################ Calculting stdP, stdQ for a single iteration ################################
    



    # Tail Pr[stdQ > t] for t > 0; same expression as the t>0 branch of oneminus_beta_p below.
    f_oneminus_beta_p = lambda t: p*oneminus_beta(t+math.log(1/p-(1-p)*math.exp(-t)/p))+(1-p)*alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p))

    # Tail mass below which the grid is truncated.
    precision = 10**(-5)*privacy_args.precision
    
    # max_t: where the tail drops to `precision` (700 if it is still above it at 700).
    if f_oneminus_beta_p(700)>=precision:
        max_t=700
    else:
        max_t=scipy.optimize.root_scalar(lambda x: f_oneminus_beta_p(x)-precision,bracket=[math.log(1-p)+10**(-10),700],method='brentq').root


    if t_limit_stdPQ==None:
        t_limit_stdPQ=max_t 

    assert(t_limit_stdPQ>=max_t)

    if verbose:
        print('t_limit_stdPQ: ',t_limit_stdPQ)
        print('max_t: ',max_t)

    # Bin edges (odd multiples of mesh_size/2) and bin centres, symmetric around 0.
    t_list_cdf_half = np.arange(mesh_size/2,t_limit_stdPQ,mesh_size)
    t_list_cdf = np.append(-np.flipud(t_list_cdf_half),t_list_cdf_half)

    t_list_half = avg_neighbors(t_list_cdf_half)
    t_list = np.append(np.append(-np.flipud(t_list_half),[0]),t_list_half)

    # alpha_p / beta_p are defined but not used below (only oneminus_beta_p is evaluated).
    def alpha_p(t):
        if t >= max_t:
            return 0
        elif t>0:
            return alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p))
        elif t>math.log(1-p):
            return  alpha(math.log((math.exp(t)-(1-p))/p))
        else:
            return 1

    def beta_p(t):
        if t >= max_t:
            return 1
        elif t>0:
            return p*beta(t+math.log(1/p-(1-p)*math.exp(-t)/p))+(1-p)*(1-alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p)))
        elif t>math.log(1-p):
            return  p*beta(math.log((math.exp(t)-(1-p))/p))+(1-p)*(1-alpha(math.log((math.exp(t)-(1-p))/p)))
        else:
            return 0

    # Piecewise tail Pr[stdQ > t]: 0 at/after max_t, 1 at/below log(1-p).
    def oneminus_beta_p(t):
        if t >= max_t:
            return 0
        elif t>0:
            return p*oneminus_beta(t+math.log(1/p-(1-p)*math.exp(-t)/p))+(1-p)*alpha(t+math.log(1/p-(1-p)*math.exp(-t)/p))
        elif t>math.log(1-p):
            return  p*oneminus_beta(math.log((math.exp(t)-(1-p))/p))+(1-p)*alpha(math.log((math.exp(t)-(1-p))/p))
        else:
            return 1




    ################# CALCULATE std_P_pdf and std_Q_pdf for a single iteration  ###################################
 


    alpha_p_list = []           # NOTE(review): unused
    oneminus_beta_p_list = []   # NOTE(review): unused


    # Adaptive grid: once the tail falls below prob_scale[n] (at t_scale[n]), use
    # steps of step_scale[n] grid points (mesh growing geometrically up to 5).
    # NOTE(review): unlike the DP-SGD version there is no +1e-40 guard here, so
    # np.geomspace would fail if oneminus_beta_p(max_t-1) were exactly 0.
    num_scales=8
    prob_scale=np.geomspace(0.01,oneminus_beta_p(max_t-1),num_scales)
    #prob_scale = [10**(-2-n) for n in range(1,num_scales+1)]
    mesh_scale=np.geomspace(mesh_size,5,num_scales)
    #mesh_scale=[mesh_size*4**n for n in range(num_scales)]
    step_scale=[int(math.ceil(mesh_scale[n]/mesh_size)) for n in range(num_scales)]
    t_scale=[scipy.optimize.root_scalar(lambda x: oneminus_beta_p(x)-prob_scale[n],bracket=[math.log(1-p)+10**(-10),max_t],method='brentq').root for n in range(num_scales)]
        
    if verbose:
        print('num_scales:',num_scales)
        print('prob_scale:',prob_scale)
        print('mesh_scale:',mesh_scale)
        print('step_scale:',step_scale)
        print('t_scale:',t_scale)

    #stdP_pdf_list=[0]
    stdQ_pdf_list=[0]


    # Return n values a0*r, ..., a0*r^n whose sum is b (geometric continuation of
    # the previous mass a0); all zeros when b <= 0.
    def pdf_interpolate(a0,b,n):
        if n==1:
            return [b]
        if b<=0:
            r=0
        else:
            if a0<=0:
                a0=10**(-30)
            f=lambda r: (a0/b)*(r**n-1)*r/(r-1)-1
            #print(a0,b,(b/a0)**(1/n))
            r = scipy.optimize.root_scalar(f,bracket=[0,(b/a0)**(1/n)]).root
        return [a0*r**i for i in range(1,n+1)]

    i=1 #counter
    n=0 #scale
    #a1=1
    b1=1  # tail value at the left edge of the current step
    # The `for j` loop only drives the progress bar; a step is processed whenever
    # j catches up with i, i.e. effectively `while i < len(t_list)`.
    for j in tqdm(range(1,len(t_list))):
        if i>j:
            continue
        else:
            if t_list_cdf[i]>t_scale[n] and n<num_scales-1:
                n+=1 #increase scale


            step = min(step_scale[n],len(t_list)-i)
            assert(step>0)
            #a2 = alpha_p(t_list_cdf[i+step])
            b2 = oneminus_beta_p(t_list_cdf[i+step])

            #assert(len(stdP_pdf_list)==i)
            assert(len(stdQ_pdf_list)==i)

            # Spread the mass b1-b2 of this step over `step` bins.
            #stdP_pdf_list+=pdf_interpolate(stdP_pdf_list[i-1],a1-a2,step)
            stdQ_pdf_list+=pdf_interpolate(stdQ_pdf_list[i-1],b1-b2,step)

            #a1=a2
            b1=b2
            i+=step



    #assert(len(t_list)==len(stdP_pdf_list))
    assert(len(t_list)==len(stdQ_pdf_list))



    if verbose:
        #print('(should be close to 1) sum stdP_pdf: ',sum(stdP_pdf_list))
        print('(should be close to 1) sum stdQ_pdf: ',sum(stdQ_pdf_list))


    #mean_stdP,stddev_stdP = find_mean_stddev(t_list,stdP_pdf_list) 
    mean_stdQ,stddev_stdQ = find_mean_stddev(t_list,stdQ_pdf_list)

    if verbose:
        #print(f'mean_stdP={mean_stdP}, stddev_stdQ={stddev_stdP}')
        print(f'mean_stdQ={mean_stdQ}, stddev_stdQ={stddev_stdQ}')


    print('Finished calculating privacy curves for single iteration\n')




    return stdQ_pdf_list, t_list, t_limit_stdPQ

