import math
import numpy as np
import matplotlib.pyplot as plt
import numpy.random as rd
import sys
import time

'''
Function to generate data from a Zipfian Distribution with a bounded support
'''
def Zipf(a: np.float64, min: np.uint64, max: np.uint64, size=None):
    """
    Generate Zipf-like random variables,
    but in inclusive [min...max] interval.

    Every integer k in the interval is drawn with probability proportional
    to 1 / k**a.

    Parameters
    ----------
    a : exponent of the power law.
    min, max : inclusive bounds of the support.  The names shadow the
        builtins; they are kept because callers may pass them by keyword.
    size : forwarded to np.random.choice (None gives a single draw).

    Returns
    -------
    The sampled value(s), taken from np.arange(min, max + 1).

    Raises
    ------
    ZeroDivisionError : if min == 0, because the weight of 0 is undefined.
    """
    if min == 0:
        raise ZeroDivisionError(
            "Zipf support cannot start at 0: the weight 1 / 0**a is undefined"
        )

    support = np.arange(min, max + 1)     # values to sample
    weights = 1.0 / np.power(support, a)  # unnormalized probabilities
    weights /= np.sum(weights)            # normalized

    return np.random.choice(support, size=size, replace=True, p=weights)


'''
Runs Brownian Mechanism to mask a hidden parameter
epsSq : list of increasing target privacy parameters (epsilon squared values)
beta : parameter to hide
'''
def BM(epsSq, beta):
    """
    Produce one noisy version of beta per entry of epsSq.

    The last entry is drawn from a normal centred on beta with variance
    1 / epsSq[-1].  Working backwards, every earlier entry is the following
    entry plus fresh zero-mean Gaussian noise whose variance is
    1 / epsSq[j] - 1 / epsSq[j + 1].

    epsSq must be non-empty (IndexError otherwise) and non-decreasing
    (math.sqrt raises ValueError on a negative variance difference).

    Returns a list with the same length as epsSq.
    """
    lastIdx = len(epsSq) - 1
    released = [0] * len(epsSq)
    # Start from the least noisy release, at the largest epsilon squared
    released[lastIdx] = rd.normal(beta, math.sqrt(1.0 / epsSq[lastIdx]))
    # Walk towards index 0, stacking extra noise on the previous release
    for j in range(lastIdx - 1, -1, -1):
        extraVariance = 1.0 / epsSq[j] - 1.0 / epsSq[j + 1]
        released[j] = released[j + 1] + rd.normal(0, math.sqrt(extraVariance))
    return released

'''
Stopping Condition based on Relative Error being small enough
'''
def stopCondition(noisyVals, epsSqArray, alpha = 0.1, constant = 1.0):
    """
    Find the first noisy value whose relative error band is tight enough.

    For each index i, with w = constant / sqrt(epsSqArray[i]), the value
    y = noisyVals[i] is accepted when |(y + w) / (y - w)| lies in
    [1 - alpha, 1 + alpha] and |y| > 1 / sqrt(epsSqArray[i]).

    Returns (epsSqArray[i], y) for the first accepted index, or
    (-1, noisyVals[-1]) when no index is accepted.

    NOTE(review): the |y| check uses a literal 1.0 rather than `constant`;
    both are equal at the default - confirm which is intended.
    """
    lowerBound = 1 - alpha
    upperBound = 1 + alpha
    for idx, noisy in enumerate(noisyVals):
        epsSq = epsSqArray[idx]
        halfWidth = constant / np.sqrt(epsSq)
        ratio = np.abs((noisy + halfWidth) / (noisy - halfWidth))
        # A NaN ratio fails the chained comparison and is skipped
        if not (lowerBound <= ratio <= upperBound):
            continue
        if np.abs(noisy) > 1.0 / np.sqrt(epsSq):
            return epsSq, noisy
    # Return -1 if we never satisfied the stopping condition
    return -1, noisyVals[-1]

'''
Exponential Mechanism for Counts
'''
def EM(eps, dictData):
    """
    Pick a key of dictData by noisy arg-max.

    Independent Gumbel noise with location 0 and scale 1 / eps is added to
    every value, and the key with the largest noisy value is returned.

    Returns -1 when dictData is empty (or when no noisy value exceeds
    -sys.maxsize).
    """
    bestKey, bestScore = -1, -sys.maxsize
    for key, count in dictData.items():
        score = count + np.random.gumbel(loc=0.0, scale=1.0 / eps)
        # Strict comparison: on a tie the earlier key is kept
        if score > bestScore:
            bestKey, bestScore = key, score
    return bestKey

# Overall privacy budget.  epsSqTotal is picked so that the converted
# guarantee printed below is roughly (10, 1e-6)-DP.
# Other epsSqTotal values that were tried: 0.77, 9
deltaAdvComp = 1e-6
epsSqTotal = 2.705

# Convert the epsilon squared budget and delta into a single epsilon
epsGlobal = epsSqTotal / 2.0 + np.sqrt(2 * epsSqTotal * np.log(1.0 / deltaAdvComp))

print("Total DP parameters are epsilon = ", epsGlobal, " and delta = ", deltaAdvComp, sep="")

# Set some parameters for the data distribution
upperBound = 300      # elements are the integers 1..upperBound
zipfParameter = 0.75  # exponent of the power law

# Plot Zipf Law #
v = np.arange(1, upperBound+1) # values to sample
p = 1.0 / np.power(v, zipfParameter)  # probabilities
p /= np.sum(p)

# Use the element values as x positions so that the bar of element k sits
# at x = k (range(upperBound) would shift every bar one position left).
plt.bar(v, p)
plt.title("Data Distribution")
plt.xlabel("Element")
plt.ylabel("Probability")
#plt.savefig('zipfData.pdf')
plt.show()

#####################################
######## Brownian Mechanism #########
#####################################
# Number of independent trials run for every sample size
nTrials = 1000
# Epsilon parameter we use for the exponential mechanism (1.0 was tried as well)
epsEM = 0.1
# Smallest epsilon squared we try
epsSqMin = 0.0001
# Number of epsilon squared values on the grid handed to the Brownian mechanism
nPoints = 1000
# An array of data sample sizes that we use in our experiments
# (an earlier configuration was 500, 1000, 2000, 4000, 8000, 16000)
sampleSizeArray = np.array([8000, 16000, 32000, 64000, 128000])
# Setting our relative error threshold
alphaRelativeError = 0.1

# Initialize the arrays that will have our results (one entry per sample size)
precisionMeanArrayBM = []
precisionStndDevArrayBM = []
recallMeanArrayBM = []
recallStndDevArrayBM = []

# Time the whole experiment (earlier notes report roughly 8 to 21 min with 1000 trials)
start = time.time()
for sampleSize in sampleSizeArray:
    nResultsBM = []
    precisionBM = []
    for i in range(nTrials):
        Zs = Zipf(zipfParameter, 1.0, upperBound, sampleSize)

        # One bin per element 1..upperBound.  The edges must run one past
        # upperBound: np.histogram closes only the last bin on the right, so
        # stopping the edges at upperBound would fold the count of the
        # largest element into its neighbour and drop it from histDict.
        x, y = np.histogram(Zs, bins=np.arange(1, upperBound + 2))
        histDict = dict(zip(y[:-1], x))
        epsSqCurr = 0

        tempDictBM = histDict.copy()
        resultDictBM = {}

        # We want to check to see if we have enough budget to do an exponential mechanism
        # followed by the smallest possible epsilon for noise reduction.
        # Also stop once every candidate has been popped: EM returns -1 on an
        # empty dict and the pop below would raise KeyError.
        while tempDictBM and epsSqCurr + epsEM**2/8.0 + epsSqMin < epsSqTotal:
            label = EM(epsEM, tempDictBM)
            startPoint = epsSqMin
            endPoint = epsSqTotal - epsSqCurr - epsEM ** 2 / 8.0
            # nPoints evenly spaced epsilon squared values in [startPoint, endPoint).
            # linspace fixes the number of points; arange with a float step can
            # emit one extra value at (or just past) endPoint.
            epsSqExPostArray = np.linspace(startPoint, endPoint, nPoints, endpoint=False)
            trueCount = tempDictBM.pop(label)
            noisyCounts = BM(epsSqExPostArray, trueCount)
            epsSqInst, noisyVal = stopCondition(noisyCounts, epsSqExPostArray, alpha = alphaRelativeError, constant = 1.0)
            if epsSqInst > 0:
                resultDictBM[label] = noisyVal
                epsSqCurr += epsEM**2/8.0 + epsSqInst
            # Otherwise, stopping condition was never satisfied
            else:
                epsSqCurr += epsEM**2/8.0 + epsSqExPostArray[-1]
            if epsSqCurr > epsSqTotal:
                print("Over Budget!!")
            if i % 10 == 0:
                print("Sample size ", sampleSize, " Round", i, " with current epsSq ", epsSqCurr, " with true count", trueCount)

        # If we do not return anything, we say precision is 1.
        if not resultDictBM:
            precisionBM.append(1)
            nResultsBM.append(0)
        else:
            # Ratio |noisy count| / true count for every returned label
            errorsDictBM = {k: abs(noisy) / histDict[k] for k, noisy in resultDictBM.items()}
            nGoodRelativeError = sum(
                1 for ratio in errorsDictBM.values()
                if 1 - alphaRelativeError < ratio < 1 + alphaRelativeError
            )
            nResultsBM.append(len(errorsDictBM))
            precisionBM.append(nGoodRelativeError / len(errorsDictBM))

    # Summary statistics over the trials of this sample size
    precisionMean = np.mean(precisionBM)
    precisionStndDev = np.std(precisionBM)
    recallMean = np.mean(nResultsBM)
    recallStndDev = np.std(nResultsBM)

    precisionMeanArrayBM.append(precisionMean)
    precisionStndDevArrayBM.append(precisionStndDev)
    recallMeanArrayBM.append(recallMean)
    recallStndDevArrayBM.append(recallStndDev)
end = time.time()
print(end - start) # Elapsed seconds

# Precision of the Brownian mechanism per sample size (mean +/- one std dev)
fig, ax = plt.subplots()
ax.errorbar(
    range(len(sampleSizeArray)),
    precisionMeanArrayBM,
    yerr=precisionStndDevArrayBM,
    marker='^',
    linestyle='None',
)
# The points are evenly spaced; label each tick with its sample size
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
ax.set_xlabel("Sample Size")
ax.set_ylabel("Precision")
ax.set_ylim(0.8, 1)

plt.show()

# Number of results returned by the Brownian mechanism per sample size
fig, ax = plt.subplots()
ax.errorbar(
    range(len(sampleSizeArray)),
    recallMeanArrayBM,
    yerr=recallStndDevArrayBM,
    marker='^',
    linestyle='None',
)
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
ax.set_xlabel("Sample Size")
ax.set_ylabel("Results Returned")

plt.show()



###############################
##### Doubling GM #############
###############################
# Use the same parameters as in the Brownian Mechanism above #

# Results, one entry per sample size
precisionMeanArrayDGM = []
precisionStndDevArrayDGM = []
recallMeanArrayDGM = []
recallStndDevArrayDGM = []

start = time.time()
for sampleSize in sampleSizeArray:
    nResultsDGM = []
    precisionDGM = []
    for _ in range(nTrials):
        Zs = Zipf(zipfParameter, 1, upperBound, sampleSize)

        # One bin per element 1..upperBound.  The edges must run one past
        # upperBound: np.histogram closes only the last bin on the right, so
        # stopping the edges at upperBound would fold the count of the
        # largest element into its neighbour and drop it from histDict.
        x, y = np.histogram(Zs, bins=np.arange(1, upperBound + 2))
        histDict = dict(zip(y[:-1], x))
        epsSqCurr = 0

        tempDictDGM = histDict.copy()
        resultDictDGM = {}

        epsSqMinTemp = epsSqMin

        # Continue while there is budget for one exponential mechanism call plus
        # the smallest Gaussian release, and while candidates remain: EM returns
        # -1 on an empty dict and the pop below would raise KeyError.
        while tempDictDGM and epsSqCurr + epsEM**2/8.0 + epsSqMinTemp < epsSqTotal:
            label = EM(epsEM, tempDictDGM)
            # We require \sum_{i=0}^t 2^i epsSqMin < epsSqTotal.  So we need to solve for t
            endEpsSqVal = np.log( (epsSqTotal - epsSqCurr - epsEM**2/8.0)/epsSqMinTemp + 1 ) / np.log(2)
            # Always try at least the smallest epsilon squared
            nDoublings = max(int(endEpsSqVal), 1)
            epsSqDoublingArray = [2**i * epsSqMinTemp for i in range(nDoublings)]

            trueCount = tempDictDGM.pop(label)
            # Every release uses fresh, independent Gaussian noise
            noisyCounts = np.array([np.random.normal(trueCount, np.sqrt(1.0/epsSq) ) for epsSq in epsSqDoublingArray])
            epsSqInst, noisyVal = stopCondition(noisyCounts, epsSqDoublingArray, alpha = alphaRelativeError, constant = 1.0)
            if epsSqInst > 0:
                resultDictDGM[label] = noisyVal
                # Charge every release up to and including the accepted one
                ind = epsSqDoublingArray.index(epsSqInst)
                epsSqCurr += epsEM**2/8.0 + np.sum(epsSqDoublingArray[:ind+1])
            else:
                # Stopping condition never satisfied: all releases are charged
                epsSqCurr += epsEM**2/8.0 + np.sum(epsSqDoublingArray)
            if epsSqCurr > epsSqTotal:
                print("Over Budget!!")

        # If we do not return anything, we say precision is 1.
        if not resultDictDGM:
            precisionDGM.append(1)
            nResultsDGM.append(0)
        else:
            # Ratio |noisy count| / true count for every returned label
            errorsDictDGM = {k: abs(noisy) / histDict[k] for k, noisy in resultDictDGM.items()}
            nGoodRelativeError = sum(
                1 for ratio in errorsDictDGM.values()
                if 1 - alphaRelativeError < ratio < 1 + alphaRelativeError
            )
            nResultsDGM.append(len(errorsDictDGM))
            precisionDGM.append(nGoodRelativeError / len(errorsDictDGM))

    # Summary statistics over the trials of this sample size
    precisionMean = np.mean(precisionDGM)
    precisionStndDev = np.std(precisionDGM)
    recallMean = np.mean(nResultsDGM)
    recallStndDev = np.std(nResultsDGM)

    precisionMeanArrayDGM.append(precisionMean)
    precisionStndDevArrayDGM.append(precisionStndDev)
    recallMeanArrayDGM.append(recallMean)
    recallStndDevArrayDGM.append(recallStndDev)
end = time.time()
print(end - start) # Elapsed seconds (an earlier note reports about 6 min with 1000 trials)

# Precision of the doubling Gaussian mechanism per sample size (mean +/- one std dev)
fig, ax = plt.subplots()
ax.errorbar(
    range(len(sampleSizeArray)),
    precisionMeanArrayDGM,
    yerr=precisionStndDevArrayDGM,
    marker='^',
    linestyle='None',
)
# The points are evenly spaced; label each tick with its sample size
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
ax.set_xlabel("Sample Size")
ax.set_ylabel("Precision")
ax.set_ylim(0.8, 1)

plt.show()


# Number of results returned by the doubling Gaussian mechanism per sample size
fig, ax = plt.subplots()
ax.errorbar(
    range(len(sampleSizeArray)),
    recallMeanArrayDGM,
    yerr=recallStndDevArrayDGM,
    marker='^',
    linestyle='None',
)
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
ax.set_xlabel("Sample Size")
ax.set_ylabel("Results Returned")

plt.show()


##############################
#### Put plots together ######
##############################
# Precision of both mechanisms on shared axes
fig, ax = plt.subplots()
ax.errorbar(range(len(sampleSizeArray)), precisionMeanArrayBM, precisionStndDevArrayBM, capsize=10, linestyle='None', marker='^', label="BM")
# [-1][0] is the error bar line collection of the container; dash it so the
# DGM bars can be told apart from the BM bars.
ax.errorbar(range(len(sampleSizeArray)), precisionMeanArrayDGM, precisionStndDevArrayDGM, capsize=10, linestyle='None', marker='*', label="DGM")[-1][0].set_linestyle('--')
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
ax.set_ylabel("Precision")
ax.set_xlabel("Sample Size")
ax.set_ylim([0.8,1])
plt.title("Precision with (" +str(round(epsGlobal) ) + "," + str(deltaAdvComp)  +")-DP")
ax.legend()
#plt.savefig('BMDGM_precisionResults2.pdf')
plt.show()


# Number of results returned by both mechanisms on shared axes
fig, ax = plt.subplots()
ax.errorbar(range(len(sampleSizeArray)), recallMeanArrayBM, recallStndDevArrayBM, capsize=10, linestyle='None', marker='^', label="BM")
# Same '*' marker for DGM as in the precision plot; with '^' on both series
# the legend entries were indistinguishable.
ax.errorbar(range(len(sampleSizeArray)), recallMeanArrayDGM, recallStndDevArrayDGM, capsize=10, linestyle='None', marker='*', label="DGM")[-1][0].set_linestyle('--')
ax.xaxis.set_ticks(range(len(sampleSizeArray)))
ax.xaxis.set_ticklabels(sampleSizeArray)
plt.title("Results Returned with (" +str(round(epsGlobal) ) + "," + str(deltaAdvComp)  +")-DP")
ax.legend()
ax.set_ylabel("Results Returned")
ax.set_xlabel("Sample Size")
#plt.savefig('BMDGM_recallResults2.pdf')
plt.show()



#################
# Try Real Data #
#################

import pandas as pd

import os

# Reads 'clean_askreddit.csv' from the current working directory; the code
# below uses its 'author' and 'clean_text' columns.
df = pd.read_csv('clean_askreddit.csv')

# One row per token: split every text on single spaces, then explode the lists
df['clean_text'] = df['clean_text'].str.split(' ')
df = df.explode('clean_text')
# One row per (author, token) pair with its number of occurrences.  The
# groupby already makes the pairs unique, so no de-duplication is needed.
df = df.groupby(['author', 'clean_text']).size().reset_index(name='count')
# Number of distinct authors per token, most widely used token first
result = df.groupby('clean_text')['author'].nunique().reset_index()
result.columns = ['clean_text', 'distinct_authors']
result = result.sort_values(by='distinct_authors', ascending=False)

# get the top-1000
topTokens = result[:1000]
domain = list(topTokens['clean_text'])
# Build the token -> distinct author count mapping in one pass.  Converting
# element by element gives plain ints and avoids int() on a one-row Series
# (deprecated in recent pandas) as well as one full scan per token.
dataDictReddit = dict(zip(domain, (int(c) for c in topTokens['distinct_authors'])))

labels = list(dataDictReddit.keys())
heights = list(dataDictReddit.values())

# Plot the histogram
plt.bar(labels, heights)
plt.show()


# Privacy budget for the real-data experiment.  With epsSqTotal = 0.0349 the
# converted guarantee printed below is roughly (1, 1e-6)-DP.
deltaAdvComp = 1e-6
epsSqTotal = 0.0349

# Convert the epsilon squared budget and delta into a single epsilon
epsGlobal = epsSqTotal / 2.0 + np.sqrt(2 * epsSqTotal * np.log(1.0 / deltaAdvComp))

print("Total DP parameters are epsilon = ", epsGlobal, " and delta = ", deltaAdvComp, sep="")


#####################################
######## Brownian Mechanism #########
#####################################
# Number of independent trials
nTrials = 1000
# Epsilon parameter we use for the exponential mechanism
epsEM = 0.01
# Smallest epsilon squared we try
epsSqMin = 0.0001**2
# Number of epsilon squared values on the grid handed to the Brownian mechanism
nPoints = 1000
# Setting our relative error threshold
alphaRelativeError = 0.01

# Time the whole experiment (an earlier note reports about 8 min with 1000 trials)
start = time.time()
nResultsBM = []
precisionBM = []
for i in range(nTrials):
    histDict = dataDictReddit.copy()
    epsSqCurr = 0

    tempDictBM = dataDictReddit.copy()
    resultDictBM = {}

    # We want to check to see if we have enough budget to do an exponential mechanism
    # followed by the smallest possible epsilon for noise reduction.
    # Also stop once every candidate has been popped: EM returns -1 on an
    # empty dict and the pop below would raise KeyError.
    while tempDictBM and epsSqCurr + epsEM**2/8.0 + epsSqMin < epsSqTotal:
        label = EM(epsEM, tempDictBM)
        trueCount = tempDictBM.pop(label)
        startPoint = epsSqMin
        endPoint = epsSqTotal - epsSqCurr - epsEM ** 2 / 8.0
        # nPoints evenly spaced epsilon squared values in [startPoint, endPoint).
        # linspace fixes the number of points; arange with a float step can
        # emit one extra value at (or just past) endPoint.
        epsSqExPostArray = np.linspace(startPoint, endPoint, nPoints, endpoint=False)
        noisyCounts = BM(epsSqExPostArray, trueCount)
        epsSqInst, noisyVal = stopCondition(noisyCounts, epsSqExPostArray, alpha = alphaRelativeError, constant = 1.0)
        if epsSqInst > 0:
            resultDictBM[label] = noisyVal
            epsSqCurr += epsEM**2/8.0 + epsSqInst
        # Otherwise, stopping condition was never satisfied
        else:
            epsSqCurr += epsEM**2/8.0 + epsSqExPostArray[-1]
        if epsSqCurr > epsSqTotal:
            print("Over Budget!!")
        if i % 10 == 0:
            print("Round", i, " with current epsSq ", epsSqCurr, " with true count", trueCount)

    # If we do not return anything, we say precision is 1.
    if not resultDictBM:
        precisionBM.append(1)
        nResultsBM.append(0)
    else:
        # Ratio |noisy count| / true count for every returned label
        errorsDictBM = {k: abs(noisy) / histDict[k] for k, noisy in resultDictBM.items()}
        nGoodRelativeError = sum(
            1 for ratio in errorsDictBM.values()
            if 1 - alphaRelativeError < ratio < 1 + alphaRelativeError
        )
        nResultsBM.append(len(errorsDictBM))
        precisionBM.append(nGoodRelativeError / len(errorsDictBM))

# Summary statistics over all trials
precisionMean = np.mean(precisionBM)
precisionStndDev = np.std(precisionBM)
recallMean = np.mean(nResultsBM)
recallStndDev = np.std(nResultsBM)

end = time.time()
print(end - start)

# Mean minus / plus one standard deviation
print(precisionMean - precisionStndDev, precisionMean + precisionStndDev)
print(recallMean - recallStndDev, recallMean + recallStndDev)


###############################
##### Doubling GM #############
###############################
# Use the same parameters as in the Brownian Mechanism above #

start = time.time()
nResultsDGM = []
precisionDGM = []
for i in range(nTrials):

    histDict = dataDictReddit.copy()
    epsSqCurr = 0

    tempDictDGM = histDict.copy()
    resultDictDGM = {}

    epsSqMinTemp = epsSqMin

    # Continue while there is budget for one exponential mechanism call plus
    # the smallest Gaussian release, and while candidates remain: EM returns
    # -1 on an empty dict and the pop below would raise KeyError.
    while tempDictDGM and epsSqCurr + epsEM**2/8.0 + epsSqMinTemp < epsSqTotal:
        label = EM(epsEM, tempDictDGM)
        # We require \sum_{i=0}^t 2^i epsSqMin < epsSqTotal.  So we need to solve for t
        endEpsSqVal = np.log( (epsSqTotal - epsSqCurr - epsEM**2/8.0)/epsSqMinTemp + 1 ) / np.log(2)
        # Always try at least the smallest epsilon squared
        nDoublings = max(int(endEpsSqVal), 1)
        epsSqDoublingArray = [2**j * epsSqMinTemp for j in range(nDoublings)]

        trueCount = tempDictDGM.pop(label)
        # Every release uses fresh, independent Gaussian noise
        noisyCounts = np.array([np.random.normal(trueCount, np.sqrt(1.0/epsSq) ) for epsSq in epsSqDoublingArray])
        epsSqInst, noisyVal = stopCondition(noisyCounts, epsSqDoublingArray, alpha = alphaRelativeError, constant = 1.0)
        if epsSqInst > 0:
            resultDictDGM[label] = noisyVal
            # Charge every release up to and including the accepted one
            ind = epsSqDoublingArray.index(epsSqInst)
            epsSqCurr += epsEM**2/8.0 + np.sum(epsSqDoublingArray[:ind+1])
        else:
            # Stopping condition never satisfied: all releases are charged
            epsSqCurr += epsEM**2/8.0 + np.sum(epsSqDoublingArray)
        if epsSqCurr > epsSqTotal:
            print("Over Budget!!")
        # Report progress on every 10th trial only, like the Brownian mechanism loop
        if i % 10 == 0:
            print("Round", i, " with current epsSq ", epsSqCurr, " with true count", trueCount)

    # If we do not return anything, we say precision is 1.
    if not resultDictDGM:
        precisionDGM.append(1)
        nResultsDGM.append(0)
    else:
        # Ratio |noisy count| / true count for every returned label
        errorsDictDGM = {k: abs(noisy) / histDict[k] for k, noisy in resultDictDGM.items()}
        nGoodRelativeError = sum(
            1 for ratio in errorsDictDGM.values()
            if 1 - alphaRelativeError < ratio < 1 + alphaRelativeError
        )
        nResultsDGM.append(len(errorsDictDGM))
        precisionDGM.append(nGoodRelativeError / len(errorsDictDGM))

# Summary statistics over all trials
precisionMean = np.mean(precisionDGM)
precisionStndDev = np.std(precisionDGM)
recallMean = np.mean(nResultsDGM)
recallStndDev = np.std(nResultsDGM)

end = time.time()
print(end - start)

# Mean minus / plus one standard deviation
print(precisionMean - precisionStndDev, precisionMean + precisionStndDev)
print(recallMean - recallStndDev, recallMean + recallStndDev)


