import math
import pandas as pd
import numpy as np
import random


### Generate the noisy sensitive attributes from exact sensitive attributes
def privatization(exact_attr, epsilon):
    """Randomize sensitive attributes with k-ary randomized response.

    Every entry keeps its own value with probability ``pi`` and is replaced by
    each one of the other observed values with probability ``pi_bar``, where
    ``pi = e^eps / (k - 1 + e^eps)``, ``pi_bar = 1 / (k - 1 + e^eps)`` and
    ``k`` is the number of distinct values found in ``exact_attr``.

    Args:
        exact_attr: 1-D array-like (``pd.Series``, ``np.ndarray``, list) of
            exact attribute values.
        epsilon: Privacy parameter of the mechanism.

    Returns:
        ``pd.Series`` with a default integer index holding one noisy value per
        input entry, in input order.
    """
    # Work on a plain array so entries are addressed by position; indexing a
    # Series with ``[i]`` is label-based and breaks on non-default indexes.
    values = np.asarray(exact_attr)

    # Sorted distinct levels plus, for each entry, the position of its level.
    # Computed once instead of re-running np.unique for every entry.
    levels, codes = np.unique(values, return_inverse=True)
    num_levels = len(levels)

    # Conditional probabilities of the epsilon-locally-differentially-private
    # mechanism.
    exp_eps = math.exp(epsilon)
    pi = exp_eps / (num_levels - 1 + exp_eps)
    pi_bar = 1 / (num_levels - 1 + exp_eps)

    # Row r holds the sampling weights used when the exact level is levels[r].
    weights = np.full((num_levels, num_levels), pi_bar, dtype=float)
    np.fill_diagonal(weights, pi)

    # One draw per entry from the ``random`` module, in input order, so results
    # under random.seed(...) match the per-entry sampling order.
    noisy = [
        random.choices(levels, weights=weights[code], k=1)[0]
        for code in codes.ravel()
    ]
    return pd.Series(np.array(noisy))

def add_noise(A, epsilon):
    """Randomly flip a binary sensitive attribute (binary randomized response).

    The two distinct values of ``A`` are coded, in sorted order, as 0 and 1.
    Each code is flipped independently with probability ``pi_bar`` (as returned
    by ``compute_pi``), and the result is reported with the labels
    ``'Female'`` (code 0) and ``'Male'`` (code 1).

    Args:
        A: 1-D array-like containing exactly two distinct values.
        epsilon: Privacy parameter forwarded to ``compute_pi``.

    Returns:
        Object-dtype ``np.ndarray`` of ``'Female'`` / ``'Male'`` labels, one per
        entry of ``A``.

    Raises:
        ValueError: If ``A`` does not contain exactly two distinct values.
    """
    values = np.asarray(A)
    levels = np.unique(values)
    if len(levels) != 2:
        raise ValueError(
            "add_noise expects exactly two distinct values, got %d" % len(levels)
        )
    _, pi_bar = compute_pi(A=len(levels), epsilon=epsilon)

    # Derive the 0/1 codes with a single comparison. Recoding the array in
    # place one level at a time can overwrite freshly written codes when a
    # level is itself equal to 0 or 1.
    codes = (values == levels[1]).astype(int)

    # coin == 1 marks the entries whose code gets flipped.
    coin = np.random.binomial(n=1, p=pi_bar, size=values.shape[0])
    flipped = np.abs(codes - coin)

    # Build the labels in a fresh object array so numeric inputs work too.
    return np.where(flipped == 0, 'Female', 'Male').astype(object)


### Project vector onto a simplex of corresponding dimensions. (Used when estimated P(S) contains negative probabilities.)
def proj_simplex(v, z = 1):
    """Project the 1-D vector ``v`` onto the simplex of total mass ``z``.

    Args:
        v: 1-D ``np.ndarray`` to project.
        z: Value the entries of the result sum to (default 1).

    Returns:
        ``np.ndarray`` shaped like ``v`` with non-negative entries, obtained by
        shifting ``v`` by a common threshold and clipping at zero.
    """
    dim = v.shape[0]

    # Entries in decreasing order and their running totals in excess of z.
    descending = np.sort(v)[::-1]
    excess = np.cumsum(descending) - z
    ranks = np.arange(1, dim + 1)

    # The last rank that still satisfies the positivity condition fixes the
    # threshold subtracted from every entry.
    active = (descending - excess / ranks) > 0
    last_active = np.flatnonzero(active)[-1]
    threshold = excess[last_active] / float(ranks[last_active])

    return np.maximum(v - threshold, 0)


### Compute the conditional probability of S given D
def compute_pi(A, epsilon):
    """Return the randomized-response probabilities ``(pi, pi_bar)``.

    Args:
        A: Number of categories of the attribute.
        epsilon: Privacy parameter of the mechanism.

    Returns:
        Tuple ``(pi, pi_bar)`` with ``pi = e^eps / (A - 1 + e^eps)`` and
        ``pi_bar = 1 / (A - 1 + e^eps)``, so that ``pi + (A - 1) * pi_bar == 1``.
    """
    exp_eps = math.exp(epsilon)
    denominator = A - 1 + exp_eps
    pi = exp_eps / denominator
    # Closed form of (1 - pi) / (A - 1): it avoids the cancellation in 1 - pi
    # (which rounds to exactly 0 for large epsilon) and does not divide by
    # A - 1.
    pi_bar = 1 / denominator
    return pi, pi_bar

### Compute the inverse of matrix T
def compute_inverse_T(A, pi):
    """Build the ``A`` x ``A`` inverse of matrix T in closed form.

    Args:
        A: Size of the (square) matrix.
        pi: Probability appearing in the closed-form entries.

    Returns:
        ``np.ndarray`` of shape ``(A, A)`` whose diagonal entries equal
        ``(pi + A - 2) / (A * pi - 1)`` and whose off-diagonal entries equal
        ``(pi - 1) / (A * pi - 1)``.
    """
    denominator = A * pi - 1
    diagonal_entry = (pi + A - 2) / denominator
    off_diagonal_entry = (pi - 1) / denominator

    # Fill everything with the off-diagonal value, then overwrite the diagonal.
    inverse = np.full((A, A), off_diagonal_entry, dtype=float)
    np.fill_diagonal(inverse, diagonal_entry)
    return inverse

### Compute the inverse of matrix Pi
def compute_inverse_Pi(A, Z, pi, pi_bar):
    """Compute the inverse of matrix Pi from the noisy attribute frequencies.

    Steps:
      1. Estimate the empirical frequency of every distinct value of ``Z``
         (in sorted order of the values).
      2. Multiply by the inverse of T (``compute_inverse_T``) to get the
         corrected frequencies ``p``. These are used as-is; they are not
         clipped or re-normalized here.
      3. Build Pi with
         ``Pi[i][i] = pi * p[i] / (pi * p[i] + pi_bar * (1 - p[i]))`` and
         ``Pi[i][j] = pi_bar * p[j] / (pi * p[i] + pi_bar * (1 - p[i]))``.
      4. Invert Pi.

    Args:
        A: 1-D collection of attribute values; only its number of distinct
            values is used.
        Z: 1-D array-like of noisy attribute values (assumed 1-D; frequencies
            are matched to levels by sorted position, not by label).
        pi: Probability weighting the diagonal entries.
        pi_bar: Probability weighting the off-diagonal entries.

    Returns:
        ``float32`` ``np.ndarray`` holding the inverse of Pi.

    Raises:
        ValueError: If ``Z`` has more distinct values than ``A``.
        numpy.linalg.LinAlgError: If Pi is singular.
    """
    num_levels = len(set(A))

    # Empirical frequencies of Z, one pass; levels missing from Z leave
    # trailing zeros in the frequency vector.
    _, z_counts = np.unique(Z, return_counts=True)
    if len(z_counts) > num_levels:
        raise ValueError("Z has more distinct values than A")
    p_z = np.zeros(num_levels)
    p_z[:len(z_counts)] = z_counts / len(Z)

    # Compute the inverse of matrix T and the corrected frequencies.
    inv_T = compute_inverse_T(A=num_levels, pi=pi)
    p_a_proj = inv_T.dot(p_z)

    # Compute the Pi matrix: each row i shares one normalizer.
    row_normalizer = pi * p_a_proj + pi_bar * (1 - p_a_proj)
    Pi = (pi_bar * p_a_proj)[np.newaxis, :] / row_normalizer[:, np.newaxis]
    np.fill_diagonal(Pi, (pi * p_a_proj) / row_normalizer)

    # Compute the inverse of the Pi matrix.
    return np.linalg.inv(Pi).astype('float32')

### Compute the estimator of C_1
def compute_C_1_hat(A, pi_hat):
    """Return the estimator of C_1, ``(pi_hat + A - 2) / (A * pi_hat - 1)``.

    Args:
        A: Numeric value entering the formula above.
        pi_hat: Estimated probability entering the formula above.
    """
    numerator = pi_hat + A - 2
    denominator = A * pi_hat - 1
    return numerator / denominator

## compute inverse of R matrix
def compute_inv_R(X_F, X_M, Z, pi):
    """Compute the four entries of the inverse of the 2 x 2 matrix R.

    Args:
        X_F: Quantity associated with the first group.
        X_M: Quantity associated with the second group.
        Z: 2-D object; only ``Z.shape[1]`` is read, as the number of groups.
        pi: Probability passed to ``compute_inverse_T``.

    Returns:
        Tuple ``(inv_R_00, inv_R_01, inv_R_10, inv_R_11)``.
    """
    num_groups = Z.shape[1]
    inv_T = compute_inverse_T(A = num_groups, pi = pi)

    # Mix X_F and X_M with the first and second rows of the inverse of T.
    mixed_first = inv_T[0][0] * X_F + inv_T[0][1] * X_M
    mixed_second = inv_T[1][0] * X_F + inv_T[1][1] * X_M

    # Same closed-form coefficients as the entries of the inverse of T.
    diag_coef = (pi + num_groups - 2) / (num_groups * pi - 1)
    off_coef = (pi - 1) / (num_groups * pi - 1)

    return (
        diag_coef * (X_F / mixed_first),
        off_coef * (X_M / mixed_first),
        off_coef * (X_F / mixed_second),
        diag_coef * (X_M / mixed_second),
    )


def get_weight_dict(df,secret_tag):
    """Map every distinct value of a column to its inverse relative frequency.

    Args:
        df: ``pd.DataFrame`` holding the data.
        secret_tag: Name of the column to weight by.

    Returns:
        Dict ``{value: len(df) / number_of_rows_with_value}``, keyed in order
        of first appearance. Missing values (``None`` / ``NaN``) get no entry.
    """
    column = df[secret_tag]
    N = len(df)

    # Count every level once; value_counts() drops missing values, which
    # would otherwise match zero rows and cause a division by zero.
    level_counts = column.value_counts().to_dict()

    return {
        level: N / int(level_counts[level])
        for level in column.unique()
        if level in level_counts
    }