import argparse
import numpy as np
from scipy.stats import norm, levy_stable
from tqdm import tqdm
import matplotlib.pyplot as plt

def mul_sens(alpha, n, m, M):
    """Return the multiplicative sensitivity factor ``rho`` for stability index ``alpha``.

    In this script it is called with ``n`` = stream size, ``m`` = domain size
    and ``M`` = value bound, and the result is passed on as ``rho`` to
    ``find_lam`` / ``get_delta``.

    Args:
        alpha: stability index, must satisfy ``0 < alpha <= 2``.
        n, m, M: numeric parameters of the bound (see above).

    Returns:
        ``1.`` for ``alpha == 1``; otherwise the closed-form expression below.

    Raises:
        AssertionError: if ``alpha`` is outside ``(0, 2]`` (via ``assert``).
    """
    assert 0 < alpha <= 2
    if alpha == 1:
        return 1.
    # (m-1)^((alpha-1)/alpha): exponent is negative for alpha < 1, positive for alpha > 1.
    shrink = np.power(m - 1, (alpha - 1) / alpha)
    if alpha > 1:
        prefactor = np.power(2.0, 2 * alpha - 2)
        base = (n - 1 + M * shrink) / n
    else:
        prefactor = np.power(2.0, 2 - 2 * alpha)
        base = (n - 1 + M) / (n - 1 + shrink)
    return prefactor * np.power(base, alpha)

def find_lam(alpha, eps, rho, x_max=20, num=2000):
    """Find the threshold ``lam`` at which the density ratio falls to ``exp(-eps)``.

    The ratio compared is
    ``levy_stable.pdf(x, alpha, 0) / levy_stable.pdf(x, alpha, 0, scale=rho**(1/alpha))``,
    i.e. the unit-scale symmetric stable density against the same density
    with scale ``rho**(1/alpha)``.

    Args:
        alpha: stability index.
        eps: target privacy parameter.
        rho: scale factor (``mul_sens`` output in this script); the
            ``alpha == 2`` closed form divides by ``rho - 1``, so ``rho``
            must not equal 1 there.
        x_max: upper end (exclusive) of the search grid for ``alpha != 2``.
        num: number of grid points for ``alpha != 2``.

    Returns:
        ``-1`` if ``eps < log(rho) / alpha`` (treated as infeasible by
        ``get_delta``); the closed-form value for ``alpha == 2``; otherwise the
        first grid point in ``[0, x_max)`` whose ratio is ``<= exp(-eps)``,
        or ``None`` if no grid point qualifies (``get_delta`` maps ``None`` to
        ``delta = 0``).

    NOTE(review): ``None`` only means "no crossing on the scanned grid"; whether
    that really implies ``delta = 0`` depends on the ratio beyond ``x_max`` -
    confirm ``x_max`` is large enough for the parameters in use.
    """
    if eps < np.log(rho) / alpha:
        return -1

    if alpha == 2:
        return np.sqrt(rho * (np.log(rho) + 2 * eps) / (rho - 1))

    points = np.linspace(0, x_max, num=num, endpoint=False)
    scale = np.power(rho, 1. / alpha)
    ratios = np.divide(levy_stable.pdf(points, alpha, 0),
                       levy_stable.pdf(points, alpha, 0, scale=scale))
    # Indices where the ratio has dropped to the target; NaN ratios compare False.
    hits = np.flatnonzero(ratios <= np.exp(-eps))
    if hits.size == 0:
        return None
    return points[hits[0]]


def get_delta(alpha, eps, rho):
    """Return the ``(eps, delta)`` pair for the given ``alpha``, ``eps`` and ``rho``.

    ``delta`` is derived from the threshold returned by ``find_lam``:

    * ``None`` from ``find_lam`` -> ``delta = 0``;
    * ``-1`` from ``find_lam`` (``eps < log(rho)/alpha``) -> ``delta = None``;
    * otherwise ``delta = 2 - 2 * levy_stable.cdf(lam, alpha, 0)``, the
      two-sided tail mass beyond ``+-lam`` of the unit-scale symmetric
      (``beta = 0``) stable distribution.
    """
    threshold = find_lam(alpha, eps, rho)
    if threshold is None:
        delta = 0
    elif threshold == -1:
        delta = None
    else:
        delta = 2 - 2 * levy_stable.cdf(threshold, alpha, 0)
    return eps, delta

if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Sensitivity and privacy-parameter experiments for symmetric '
                    'alpha-stable distributions (scipy.stats.levy_stable, beta=0).')
    parser.add_argument('--alpha', type=float, default=2.0,
                        help='stability index; only used by --figure ratio')
    parser.add_argument('--stream_size', type=int, default=10000,
                        help='n passed to mul_sens (sens/privacy figures)')
    parser.add_argument('--domain_size', type=int, default=1000,
                        help='m passed to mul_sens')
    parser.add_argument('--value_bound', type=int, default=1,
                        help='M passed to mul_sens')
    # NOTE(review): --tick is parsed but not read by any branch below; kept so
    # existing command lines keep working.
    parser.add_argument('--tick', type=float, default=0.1)
    # Restricting --figure to the implemented branches makes a typo fail loudly
    # instead of silently running nothing.
    parser.add_argument(
        '--figure', default='sens',
        choices=['sens', 'privacy', 'ratio', 'eps'],
        help=('sens: mul_sens vs. alpha, legend stream_size, fixed domain_size, value_bound; '
              'privacy: eps vs. delta, legend alpha, fixed domain_size, stream_size, value_bound; '
              'ratio: pdf ratio plot saved to test.png; '
              'eps: per-alpha eps estimates at the largest stream size'))
    args = parser.parse_args()

    if args.figure == 'sens':
        arr = [0.05, 0.5, 0.95, 1, 1.05, 1.5, 2]  # alternative sweep: np.linspace(0.1, 2, 20)
        sens = [mul_sens(alpha, args.stream_size, args.domain_size, args.value_bound)
                for alpha in arr]
        for x, y in zip(arr, sens):
            print(x, ", ", y)

    elif args.figure == 'privacy':
        arr = [1.5, 2]
        epsilons = [0.2, 0.4, 0.6, 0.8, 1.0, 2.0, 4.0]
        for alpha in arr:
            rho = mul_sens(alpha, args.stream_size, args.domain_size, args.value_bound)
            for eps in epsilons:
                # Prints (eps, delta); delta is None when eps < log(rho)/alpha.
                print(get_delta(alpha, eps, rho))

    elif args.figure == 'ratio':
        # Ratio of the unit-scale pdf to the scale-2 pdf on [0, 100].
        x = np.linspace(0, 100, num=101)
        y = np.divide(levy_stable.pdf(x, args.alpha, 0), levy_stable.pdf(x, args.alpha, 0, scale=2))
        print(y)
        plt.plot(x, y)
        plt.savefig('test.png')

    elif args.figure == 'eps':
        # alphas = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
        alphas = [0.95, 1.05]
        # Only the last (largest) size of the log-spaced sweep, i.e. 10**7, is evaluated.
        stream_size = np.ceil(np.logspace(4, 7, num=10)).astype(int)[-1:]
        for alpha in alphas:
            print("alpha: ", alpha)
            if alpha <= 1:
                for size in stream_size:
                    rho = mul_sens(alpha, size, args.domain_size, args.value_bound)
                    # log(rho)/alpha is the eps below which find_lam reports infeasible.
                    print(size, np.log(rho)/alpha)
            elif alpha == 2:
                epss = np.linspace(0.7, 5, num=100)  # loop-invariant eps grid
                for size in stream_size:
                    rho = mul_sens(alpha, size, args.domain_size, args.value_bound)
                    # Print the first grid eps for which the expression drops below
                    # 1e-5; nothing is printed for this size if none does.
                    for eps in epss:
                        if np.sqrt(2*(rho-1)/np.pi/rho/(np.log(rho)+2*eps)) * np.exp(-2*eps*rho/(rho-1)) < 1e-5:
                            print(size, eps)
                            break
            elif alpha < 2:
                x = np.linspace(2, 10, num=801)  # loop-invariant evaluation grid
                for size in stream_size:
                    rho = mul_sens(alpha, size, args.domain_size, args.value_bound)
                    y = np.divide(levy_stable.pdf(x, alpha, 0), levy_stable.pdf(x, alpha, 0, scale=np.power(rho, 1/alpha)))
                    print(size, -np.log(np.min(y)))
            else:
                raise NotImplementedError

