import torch
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import argparse
import pdb
from time import time
import numba
from numba import njit, prange
from scipy.special import comb
from scipy.spatial.distance import cdist
import pdb
import os
# Limit OpenMP worker threads.
# NOTE(review): torch and numba are imported above this line; any runtime that
# reads OMP_NUM_THREADS when it is imported or first initialised may ignore
# this value -- confirm it takes effect (or move it above the imports).
os.environ.update(OMP_NUM_THREADS="24")

# Seeds only the stdlib `random` module. NumPy's global RNG (np.random.choice
# in main) keeps separate state that this call does not touch.
random.seed(111_111)

# Tiny additive guard used to keep log() arguments and divisors off exact zero.
CONST = 1e-64

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise for arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def sigmoid_schedule(t, start=0, end=3, tau=0.3, clip_min=1e-9):
    """Sigmoid-based decay schedule evaluated at normalised time ``t``.

    The raw value is 1 at ``t = 0`` and 0 at ``t = 1``. It follows a sigmoid
    curve between ``start / tau`` and ``end / tau``, and is finally clipped to
    ``[clip_min, 0.99]``.
    """
    low = sigmoid(start / tau)
    high = sigmoid(end / tau)
    current = sigmoid((t * (end - start) + start) / tau)
    normalised = (high - current) / (high - low)
    return np.clip(normalised, clip_min, 0.99)

def cosine_schedule(t, T, s=0.01, clip_min=1e-9, clip_max=0.99):
    """Cosine schedule value at step ``t`` of ``T``, normalised to 1 at t = 0.

    With ``f(t) = cos((t / T + s) / (1 + s) * pi / 2)``, this returns
    ``f(t) / f(0)`` clipped to ``[clip_min, clip_max]``. Unlike the cos**2
    improved-DDPM form, ``f`` is used unsquared. In main the result is used as
    a per-step coefficient whose cumulative product is taken.

    The normaliser is now ``f(0)`` itself. The previous denominator used ``pi``
    instead of ``pi / 2``, which does not match ``f``: it made the unclipped
    value at t = 0 about 1.0004 (for s = 0.01) and scaled every step by the
    same spurious factor.

    :param t: step index (scalar or array; broadcasts through NumPy).
    :param T: total number of steps.
    :param s: small offset that keeps f(t) away from its extremes near t = 0.
    :param clip_min: lower clip bound.
    :param clip_max: upper clip bound (default 0.99, as before).
    :return: clipped schedule value(s).
    """
    def _f(step):
        return np.cos((step / T + s) / (1 + s) * np.pi * 0.5)

    alpha_t = _f(t) / _f(0)
    return np.clip(alpha_t, clip_min, clip_max)

def reciprocal_schedule(t, T, clip_min=1e-9):
    """Return ``1 - 1 / (T - t + 1)`` clipped to ``[clip_min, 0.99]``."""
    remaining = T - t + 1
    value = 1 - 1 / remaining
    return np.clip(value, clip_min, 0.99)

# @njit
def find_matching_indices(data, sample_index):
    """For each column of ``data``, list the rows that share the sample's value.

    :param data: 2-D array (rows = samples, columns = features).
    :param sample_index: row whose per-column values are matched.
    :return: list with one integer index array per column; each array holds the
        row indices ``r`` with ``data[r, col] == data[sample_index, col]``.
    """
    reference = data[sample_index]
    return [np.flatnonzero(data[:, col] == reference[col])
            for col in range(data.shape[1])]

# @njit
def s_psi_t_func(sample_distance_table, sample_index, data, A_t, B_t, R_bar, t, n):
    """Compute the per-step term s_psi for sample ``sample_index`` at step ``t``.

    ``sim`` is the sum of ``R_bar[t+1] ** -d`` over every entry ``d`` of
    ``sample_distance_table``, minus 1 (presumably the sample's own
    zero-distance entry, as in main -- confirm for other callers).

    For each of the first ``n`` feature columns, the same sum is restricted to
    the rows that share the sample's value in that column (``sim_col``). That
    column contributes ``B_t / (1 + sim + R_bar[t]**2 * sim_col)``.

    :param sample_distance_table: distances from the sample to every row of
        ``data`` (a row of the distance table in main).
    :param sample_index: row of ``data`` being evaluated.
    :param data: sample matrix (rows = samples, columns = features).
    :param A_t: step coefficient A for this step.
    :param B_t: step coefficient B for this step.
    :param R_bar: R_bar sequence; entries ``t`` and ``t + 1`` are read.
    :param t: step index.
    :param n: number of feature columns to use (also the averaging divisor).
    :return: ``max(A_t / (1 + sim) * log(1 + mean_of_column_terms), 1e-32)``.
    """
    ratio = R_bar[t + 1]
    sim = np.power(ratio, -sample_distance_table).sum() - 1
    matches = find_matching_indices(data, sample_index)
    scale = R_bar[t] ** 2
    base = 1 + sim
    column_terms = np.array([
        B_t / (base + scale * (np.power(ratio, -sample_distance_table[matches[col]]).sum() - 1))
        for col in range(n)
    ])
    mean_term = 1 / n * np.sum(column_terms)
    return max(A_t / (1 + sim) * math.log(1 + mean_term), 1e-32)

# Algorithm 1
# @njit
def eta_to_eta_star(eta, sample_index, N_table, mu_bar_minus, s, n):
    """Map a distance threshold ``eta`` to ``eta_star`` (Algorithm 1).

    Scans candidate thresholds and returns the first ``eta_star`` whose
    normalised excess ``eta_star / eta - 1`` (with CONST guarding eta == 0)
    clears a bound built from the counts in ``N_table`` and from
    ``mu_bar_minus``. Returns ``n`` when no candidate qualifies.

    :param eta: starting distance threshold (main passes 0 <= eta < n).
    :param sample_index: row of ``N_table`` to read.
    :param N_table: cumulative count table. In main, N_table[i, d] is the
        number of samples within Hamming distance <= d of sample i, with
        columns 0..n.
    :param mu_bar_minus: scalar mu_bar_minus value for the current step.
    :param s: total number of samples.
    :param n: number of features (largest possible distance).
    :return: int eta_star, or ``n`` if no candidate passes.
    """
    # Count at distance min(2*eta, n-1).
    # NOTE(review): N_table has columns 0..n and the loops below read up to
    # column n; confirm the n-1 cap here is intended.
    x = N_table[sample_index, min(2*eta, n-1)]
    # Choose the search regime by comparing the count ratio x / (s - x)
    # against (2e * mu_bar_minus) ** eta.
    if x/(s-x+CONST) > (2*math.e*mu_bar_minus)**eta:
        den = math.log(1/(mu_bar_minus + CONST)+CONST)
        # Search eta_star in [eta, min(2*eta, n+1)); this range is empty when eta == 0.
        for eta_star in range(eta, min(2*eta, n+1)):
            neta_t = N_table[sample_index, eta_star]
            # Divides by neta_t without the CONST guard used in the else-branch.
            # This is safe when neta_t >= 1, which holds in main because each
            # sample counts itself at distance 0. Confirm for other callers.
            if eta_star/(eta + CONST)-1 >= (1/(eta + CONST)*math.log((s-neta_t)/neta_t + CONST)+1+math.log(2))/den:
                return eta_star
    else:
        den = math.log(1/(mu_bar_minus+CONST))-1+CONST
        # Search eta_star in [2*eta, n].
        for eta_star in range(2*eta, n+1):
            neta_t = N_table[sample_index, eta_star]
            if eta_star/(eta + CONST)-1 >= (1/(eta + CONST)*math.log((s-neta_t)/(neta_t+CONST) + CONST))/den:
                return eta_star
    # Fallback: no candidate satisfied the bound.
    return n

def binomial_upper_cumulative_prob(n, p, a):
    """
    Return the upper tail P(S_n >= a) of a Binomial(n, p) variable S_n.

    The tail is summed term by term as C(n, k) * p**k * (1 - p)**(n - k) for
    k = a .. n. The result is 0 when a > n.

    :param n: Number of trials
    :param p: Probability of X_i being 1
    :param a: Starting value
    :return: Value of P(S_n >= a)
    """
    q = 1 - p
    return sum(comb(n, k) * (p ** k) * (q ** (n - k)) for k in range(a, n + 1))

def find_smallest_a(n, p, c):
    """
    Find the smallest a such that c >= P(S_n >= a), with S_n ~ Binomial(n, p).

    P(S_n >= a) is non-increasing in a, and binomial_upper_cumulative_prob
    returns 0 for every a > n. So for any c >= 0 the answer lies in
    0 .. n + 1 and the search is bounded. The previous ``while True`` loop
    never terminated for a negative or NaN threshold; that case now raises.

    :param n: Number of trials
    :param p: Probability of X_i being 1
    :param c: Given threshold
    :return: Smallest value of a
    :raises ValueError: if no a in 0 .. n + 1 qualifies (c negative or NaN).
    """
    # max(n, 0) keeps a = 0 reachable for degenerate n, matching the old loop.
    for a in range(max(n, 0) + 2):
        if c >= binomial_upper_cumulative_prob(n, p, a):
            return a
    raise ValueError(f'no a satisfies c >= P(S_n >= a) for n={n}, p={p}, c={c}')

def find_optimal_eta_etaprime_gap(eta, eta_star, N_table, n, s, mu_bar_plus_t):
    """Return the smallest binomial tail index for the (eta, eta_star) pair.

    The threshold is ``N_table[eta_star] / (s - N_table[eta])``. The success
    probability is ``1 - mu_bar_plus_t``, over ``n`` trials.

    :param N_table: 1-D cumulative count row (main passes N_table[0]).
    :return: result of find_smallest_a for that threshold.
    """
    threshold = N_table[eta_star] / (s - N_table[eta])
    return find_smallest_a(n, 1 - mu_bar_plus_t, threshold)

def main(p):
    """Run one data-dependent privacy evaluation for a given skewness ``p``.

    Steps:
      1. Parse command-line options from ``sys.argv`` (``--n``, ``--k``,
         ``--s``, ``--t``, ``--epsilon``, ``--compress``, ``--label``).
      2. Build a dataset of ``s`` rows and ``n`` categorical features. Row 0 is
         the target record ``v_star`` (every feature equal to 3). The other
         ``s - 1`` rows are drawn i.i.d. with probability ``p`` for category 0
         and ``(1 - p) / (k - 1)`` for each other category.
      3. Derive per-step diffusion quantities (mu, R_bar, A, B, rho) from a
         cosine schedule over ``T = --t`` steps.
      4. Compute the pairwise Hamming-distance table and the cumulative counts
         ``N_table[i, d]``: the number of samples within distance <= d of i.
      5. For each step t in 0 .. T-2, search eta_t^* for row 0 and evaluate the
         per-step privacy bound, scaled into the DP framework by ``epsilon``.

    Fixes relative to the previous version:
      - ``--compress`` now works: rows are gathered from the packed triangle.
      - ``--epsilon`` accepts real values.
      - eta_t^* is stored as int64 instead of int8.
      - An empty candidate list no longer raises.

    :param p: probability of category 0 in the skewed distribution.
    :return: ``(DP_avg_step, s_psi_avg_step, N_over_s_avg_step,
        array_eta_star)``. Each is a 1-D array of length ``T - 1`` holding, per
        step, the scaled privacy value, s_psi, N_table[0, eta_t^*] / s and
        eta_t^*.
    """
    parser = argparse.ArgumentParser(description="Arg Parse for Diffusion Model privacy")
    parser.add_argument('--n', dest='n', type=int,
                        default=5, help='feature dimension')
    parser.add_argument('--k', dest='k', type=int,
                        default=5, help='categories')
    # parser.add_argument('--p', dest='p', type=float,
    #                     default=0.5, help='skewness')
    parser.add_argument('--s', dest='s', type=int,
                        default=1001, help='number of samples')
    parser.add_argument('--t', dest='t', type=int,
                        default=20, help='steps of diffusion')
    # epsilon is a real-valued DP parameter; type=int rejected inputs such as
    # "--epsilon 0.5". Integer inputs give the same results as before.
    parser.add_argument('--epsilon', dest='epsilon', type=float, default=1)
    parser.add_argument('--compress', action='store_true',
                        help='enable compression')
    # NOTE(review): --label is parsed but not used in this function.
    parser.add_argument('--label', type=int, default=1)
    args = parser.parse_args()

    # Parameters of the skewed categorical distribution (p comes from the caller).
    k = args.k
    q = (1 - p) / (k-1)
    p_prob = [p]+[q]*(k-1)

    # s: number of samples, n: number of features.
    s = args.s
    n = args.n

    # Row 0 is the target record v_star (every feature equal to 3). The other
    # s - 1 rows come from the skewed distribution. np.random is not seeded in
    # this file, so each call draws a fresh dataset.
    data_star = np.full((1, n), 3)
    data_1 = np.random.choice(np.arange(k), size=(s-1, n), p=p_prob)
    data = np.concatenate((data_star, data_1), axis=0)

    # Number of diffusion steps.
    T = args.t

    # Diffusion coefficients, one per step. Alternatives kept for experiments:
    ## 1. Linear schedule
    # coeff = 1-1e-2-np.arange(1, T+1)*(1-1e-2)/T
    # coeff = np.array([1 - 1e-7, 1-1e-6, 1-1e-5, 1-1e-3, 1-1e-1])
    ## 2. Sigmoid schedule
    # coeff = [sigmoid_schedule(t / T, start=0, end=3, tau=0.4, clip_min=1e-9) for t in range(T)]
    # coeff = np.array(coeff)
    ## 4. Reciprocal schedule
    # coeff = [reciprocal_schedule(t, T, clip_min=1e-9) for t in range(T)]
    # coeff = np.array(coeff)
    ## 3. Cosine schedule (active)
    coeff = [cosine_schedule(t, T, s=0.01) for t in range(T)]
    coeff = np.array(coeff)

    # bar_coeff: cumulative product of the coefficients, computed in log space.
    alpha_list = np.log(coeff + CONST)
    log_sum = np.cumsum(alpha_list)
    bar_coeff = np.exp(log_sum)

    # mu_plus / mu_minus (per step) and mu_bar_plus / mu_bar_minus (cumulative).
    # Each is prefixed with a step-0 entry, which shifts the coefficient-derived
    # values to indices 1..T.
    mu_plus = (1 + (k-1) * coeff) / k
    mu_minus = (1 - coeff) / k
    mu_plus = np.insert(mu_plus, 0, 1)
    mu_minus = np.insert(mu_minus, 0, CONST)

    mu_bar_plus = (1 + (k-1) * bar_coeff) / k
    mu_bar_minus = (1 - bar_coeff) / k
    mu_bar_plus = np.insert(mu_bar_plus, 0, 1)
    mu_bar_minus = np.insert(mu_bar_minus, 0, CONST)

    R_bar = (1 + (k-1) * bar_coeff) / (1 - bar_coeff)
    # NOTE(review): the +100 offset on the prepended R_bar[0] is a magic
    # constant whose rationale is not visible here -- confirm.
    element_to_add = (1 + (k-1) * coeff[0]) / (1 - coeff[0]) + 100
    R_bar = np.insert(R_bar, 0, element_to_add)

    # A, B lists used to compute s_psi_t.
    A = mu_plus[1:] * (mu_bar_plus[:-1] / mu_bar_plus[1:] - mu_bar_minus[:-1] / mu_bar_minus[1:])
    B = R_bar[:-1]**2 - 1

    # rho, shape (1, T).
    term_1 = (1 - mu_plus[1:] * mu_bar_plus[:-1] / mu_bar_plus[1:])
    term_2 = (1 - 1 / R_bar[:-1])
    term_3 = mu_plus[1:] * mu_bar_plus[:-1] / mu_bar_plus[1:]
    term_4 = (1 - R_bar[1:] / R_bar[:-1])
    rho = np.array([term_1 * term_2 + term_3 * term_4])

    if args.compress:
        # Upper triangle (diagonal included) of the symmetric s x s distance
        # matrix, stored row by row in a flat vector of length s*(s+1)/2.
        distance_table = np.zeros(s*(s+1)//2, dtype=np.int32)

        @njit
        def mat2vec(i, j):
            return i*s-(i-1)*i//2+j-i if i <= j else j*s-(j-1)*j//2+i-j

        @njit
        def get_row(i):
            # Flat positions of full-matrix row i (entries j < i use the symmetric slot).
            k = i*(s-1)-(i-1)*i//2
            idx = np.arange(k, k+s)
            for j in range(i):
                idx[j] = mat2vec(i, j)
            return idx

        @njit(parallel=True)
        def hamming(distance):
            for i in prange(s):
                for j in range(i+1, s):
                    distance[mat2vec(i, j)] = (data[i] != data[j]).sum()

        @njit(parallel=True)
        def freq(table, source):
            for i in prange(s):
                # The packed table is 1-D; gather row i before counting.
                # (Previously source[i] was a scalar and this path crashed.)
                source_i = source[get_row(i)]
                for j in range(s):
                    table[i, source_i[j]] += 1

        def distance_row(i):
            return distance_table[get_row(i)]
    else:
        distance_table = np.zeros((s, s), dtype=np.int32)

        @njit(parallel=True)
        def hamming(distance):
            for i in prange(s):
                for j in range(i+1, s):
                    distance[i, j] = distance[j, i] = (
                        data[i] != data[j]).sum()

        @njit(parallel=True)
        def freq(table, source):
            for i in prange(s):
                source_i = source[i]
                for j in range(s):
                    table[i, source_i[j]] += 1

        def distance_row(i):
            return distance_table[i]

    sta = time()
    # Calculating the hamming distance matrix s * s
    hamming(distance_table)
    print(f'dt distance_parallel {time()-sta:.2f}s')

    # N_table[i, d]: number of samples within Hamming distance <= d of sample i
    # (sample i included, via its zero diagonal entry), shape (s, n + 1).
    N_table = np.zeros((s, n+1), dtype=np.int32)
    sta = time()
    freq(N_table, distance_table)
    print(f'dt freq_count_parallel {time()-sta:.2f}s')
    N_table = np.cumsum(N_table, axis=1)

    # find_eta: data-dependent privacy algorithm that searches the optimal
    # eta_t^* for sample 0 at every step t. It runs as plain Python (not jitted).
    def find_eta(array_eta_star, T, sample_distances, N_table, s_psi, s, n, A, B, R_bar, mu_bar_plus, mu_bar_minus):
        for t in range(0, T-1):
            s_psi[0, t] = s_psi_t_func(sample_distances, 0, data, A[t+1], B[t+1], R_bar, t, n)
            eta_list = []
            for eta in range(0, n):
                eta_star = eta_to_eta_star(eta, 0, N_table, mu_bar_minus[t+1], s, n)
                x, y = N_table[0, eta], N_table[0, eta_star]
                num = math.log((s-x)/y + CONST)
                eta_gap = find_optimal_eta_etaprime_gap(eta, eta_star, N_table[0], n, s, mu_bar_plus[t])
                if eta_gap >= max((num + math.log(max(A[t+1] * B[t+1] * s_psi[0, t], 1e-32) / R_bar[t]**2)) / (2*math.log(R_bar[t+1]) + CONST), 0) - 2:
                    eta_list.append(eta_star)
            # If no eta qualifies, fall back to eta_t^* = n. N_table[0, n] == s,
            # so the step term's min(4 * N / s, 1) takes its largest value, 1.
            # Previously min([]) raised ValueError and aborted the whole run.
            array_eta_star[0, t] = min(eta_list, default=n)

    # int64 rather than int8: eta_t^* ranges up to n, which overflows int8 for n > 127.
    array_eta_star = np.zeros((1, T-1), dtype=np.int64)
    s_psi = np.zeros((1, T-1))

    sta = time()
    find_eta(array_eta_star, T, distance_row(0), N_table, s_psi, s, n, A, B, R_bar, mu_bar_plus, mu_bar_minus)
    print(f'dt eta_star_parallel {time()-sta:.2f}s')

    # #print total privacy
    # privacy_1 = n*(np.minimum(4 * np.take_along_axis(N_table, array_eta_star, axis=1) / s, 1) * s_psi).sum(axis=1)
    # privacy_2 = (rho / s**2).sum(axis=1)

    # privacy = privacy_1 + privacy_2

    # print(f'Privacy list: {privacy}')
    # print('*****************************************************************')
    # print(f'Most private sample: {privacy.max()}')
    # print(f'Least private sample: {privacy.min()}')

    # # DP Framework
    # epsilon = args.epsilon
    # DP_list = privacy / epsilon / (1 - math.pow(math.e, -epsilon))
    # print('DP matrix: ' + str(DP_list))
    # print('DP max: ' + str(max(DP_list)))
    # print('DP min: ' + str(min(DP_list)))
    # print('DP avg: ' + str(np.mean(DP_list)))

    ### Per-step privacy for sample 0
    privacy_step_1 = n*(np.minimum(4 * np.take_along_axis(N_table[0], array_eta_star[0], axis=0) / s, 1) * s_psi[0])
    privacy_step_2 = (rho / s**2).flatten()[1:]
    privacy_step = privacy_step_1 + privacy_step_2

    # DP Framework
    epsilon = args.epsilon
    DP_matrix = privacy_step / epsilon / (1 - math.pow(math.e, -epsilon))
    DP_avg_step = DP_matrix
    print('Skewness p: ' + str(p) + ' DP avg: ' + str(DP_avg_step))

    s_psi_avg_step = s_psi[0]
    print('s_psi: ' + str(s_psi_avg_step))

    N_over_s_avg_step = np.zeros(T-1)
    for t in range(T-1):
        N_over_s_avg_step[t] = N_table[0, array_eta_star[0, t]] / s
    print('array_eta_star: ' + str(array_eta_star[0]))

    return DP_avg_step, s_psi_avg_step, N_over_s_avg_step, array_eta_star.flatten()

if __name__ == '__main__':
    # Exp 1. Per-step privacy and data quantity as a function of skewness p.
    # For every skewness, main(p) runs `runs` times; np.random is not reseeded
    # between runs, so each run uses a fresh dataset. The per-step mean and
    # std across runs are saved.
    skew_list = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
    runs = 5
    # Metric keys double as the file-name stem: <avg|std>_<key>_data.npz.
    metric_keys = ('privacy', 'psi', 'N', 'star')
    avg_results = {key: [] for key in metric_keys}
    std_results = {key: [] for key in metric_keys}
    for p in skew_list:
        per_run = {key: [] for key in metric_keys}
        for _ in range(runs):
            DP_avg_step, s_psi_avg_step, N_over_s_avg_step, array_eta_star = main(p)
            per_run['privacy'].append(DP_avg_step)
            per_run['psi'].append(s_psi_avg_step)
            per_run['N'].append(N_over_s_avg_step)
            per_run['star'].append(array_eta_star)
        for key, values in per_run.items():
            stacked = np.array(values)
            avg_results[key].append(np.mean(stacked, axis=0))
            std_results[key].append(np.std(stacked, axis=0))

    # np.savez does not create missing directories. Create the output folder
    # first so the results computed above are not lost to FileNotFoundError.
    out_dir = os.path.join(os.getcwd(), 'skewness_each_step')
    os.makedirs(out_dir, exist_ok=True)
    for prefix, results in (('avg', avg_results), ('std', std_results)):
        for key, per_skew in results.items():
            # One array per skewness, keyed by its string form (e.g. '0.9').
            data_dict = {str(skew): arr for skew, arr in zip(skew_list, per_skew)}
            np.savez(os.path.join(out_dir, f'{prefix}_{key}_data.npz'), **data_dict)