import numpy as np
import scipy

from NIG import NIG_rvs_single_variance
from NIW import NIW_rvs
from util import symmetrize

def get_raw_data_syn(nP, data_dim, N):
    """Build one synthetic regression dataset per party.

    A single set of true parameters is drawn from the priors returned by
    ``setup_prior_params(data_dim)`` and shared by all parties.

    Parameters
    ----------
    nP : number of parties (datasets) to generate.
    data_dim : number of input features, excluding the bias column.
    N : indexable of length >= nP; ``N[i]`` is the sample count of party i.

    Returns
    -------
    (dataset, true_params) where ``dataset`` is a list of dicts with keys
    ``'X'``, ``'y'`` and ``'N'``, and ``true_params`` is the dict produced by
    ``generate_true_params``.
    """
    priors_for_data, priors_for_model = setup_prior_params(data_dim)
    true_params = generate_true_params(priors_for_data, priors_for_model)

    dataset = []
    for party in range(nP):
        n_samples = N[party]
        features, targets = generate_data(true_params, n_samples, data_dim)
        dataset.append({'X': features, 'y': targets, 'N': n_samples})

    return dataset, true_params

def compute_sensitivity_Gauss(dataset):
    """Return the largest per-party sensitivity bound over all parties.

    Parameters
    ----------
    dataset : sequence of dicts, each with ``'X'`` (design matrix that already
        contains the extra constant-1 column) and ``'y'`` (targets).

    Returns
    -------
    The maximum of the per-party bounds, or ``0.`` for an empty ``dataset``.
    """
    # An earlier variant synchronised R_x and R_y across parties and used the
    # single bound sqrt(1.5*R_y**4 + 2*R_x**4 + 2*(R_y*R_x)**2); the per-party
    # bound below is what is actually computed.
    worst = 0.

    for party in dataset:
        X = party['X']
        y = party['y']

        # Largest row-wise L2 norm of X (bias column included) and largest |y|.
        R_x = np.sqrt(np.sum(X**2, 1)).max()
        # NOTE(review): builtin max iterates over the first axis, so for a
        # column-vector y this is a 1-element array rather than a scalar, and
        # the returned bound inherits that shape.
        R_y = max(abs(y))

        all_x_same_sign = X.max()*X.min() > 0
        if all_x_same_sign:
            # t = \sum_j min_i(|x_ij|)^2, the per-column smallest magnitudes
            t = np.sum(np.abs(X).min(axis=0)**2)
        else:
            t = max(-0.5*R_y**2, -R_x**2)

        all_y_same_sign = y.max()*y.min() > 0
        if all_y_same_sign:
            y_term = (y.max()**2 - y.min()**2)**2
        else:
            y_term = R_y**4

        # Terms are summed in the same left-to-right order for both branches.
        bound = np.sqrt(2*R_x**4 - 2*t**2 + 2*(R_y*R_x)**2 - 2*(R_y**2)*t + y_term)

        if bound > worst:
            worst = bound

    return worst


def generate_private_data_Gauss(dataset, sensitivity, epsilon, lamb=2):
    """Compute and privatize the sufficient statistics of every party.

    Parameters
    ----------
    dataset : sequence of dicts with keys ``'X'`` and ``'y'``; ``y`` must be
        2-D because ``y.T.dot(y)[0, 0]`` is read from it.
    sensitivity : sensitivity passed unchanged to ``privatize_suff_stats_Gauss``.
    epsilon : indexable; ``epsilon[i]`` is the privacy parameter of party i.
    lamb : forwarded to ``privatize_suff_stats_Gauss``.

    Returns
    -------
    (S, Z, sigma_DP): the list of exact statistics (``'XX'``, ``'Xy'``,
    ``'yy'``), the list of their noisy counterparts, and an array with the
    noise scale used for each party.
    """
    n_parties = len(dataset)
    S, Z = [], []
    sigma_DP = np.zeros(n_parties)

    for idx, party in enumerate(dataset):
        X = party['X']
        y = party['y']

        exact = {
            'XX': X.T.dot(X),
            'Xy': X.T.dot(y),
            'yy': y.T.dot(y)[0, 0],
        }
        S.append(exact)

        noisy, sigma_DP[idx] = privatize_suff_stats_Gauss(exact, sensitivity, epsilon[idx], lamb=lamb)
        Z.append(noisy)

    return S, Z, sigma_DP

# def generate_private_data(dataset, sensitivity_x, sensitivity_y, epsilon, DP_method='Lap'):

#     nP = len(dataset)
#     Z = []
#     S = []
#     sigma_DP = np.zeros(nP)

#     for i in range(nP):

#         X, y = dataset[i]['X'], dataset[i]['y']
#         S.append({'XX': X.T.dot(X), 'Xy': X.T.dot(y), 'yy': y.T.dot(y)[0, 0]})

#         Zi, sigma_DP[i] = privatize_suff_stats(S[i], sensitivity_x, sensitivity_y, epsilon[i], DP_method)

#         Z.append(Zi)

#     return S, Z, sigma_DP

# def setup_data(data_dim, N):

#     data_prior_params, model_prior_params = setup_prior_params(data_dim)
#     true_params = generate_true_params(data_prior_params, model_prior_params)

#     X, y, sensitivity_x, sensitivity_y = generate_data(true_params, N, data_dim)

#     return data_prior_params, model_prior_params, X, y, sensitivity_x, sensitivity_y, true_params


def setup_prior_params(data_dim):
    """Return the fixed hyperparameters of the data prior and the model prior.

    Parameters
    ----------
    data_dim : number of input features, excluding the bias term.

    Returns
    -------
    (data_prior_params, model_prior_params)
        ``data_prior_params`` is the NIW list ``[mu_0, lambda_0, psi_0, nu_0]``
        for p(\\mu, \\Sigma) with p(x) = N(\\mu, \\Sigma).
        ``model_prior_params`` is the NIG list ``[mu_0, lambda_0, a_0, b_0]``
        for p(\\theta, \\sigma^2); it has ``data_dim + 1`` dimensions because
        y = \\theta*x + b carries a bias.
    """
    model_dim = data_dim + 1    # +1 is the bias

    # --- NIW prior over the inputs ---
    niw_mean = np.array([0] * data_dim).reshape(-1, 1)
    niw_scale_matrix = np.diag([1] * data_dim)
    data_prior_params = [niw_mean, 1, niw_scale_matrix, 50]

    # --- NIG prior over the regression model ---
    # p(\sigma^2) = InvGamma(a_0, b_0) -- in around 0.015~0.045
    a_0 = 20
    b_0 = .5
    c_0 = 1
    lambda_0 = b_0/(a_0 - 1) / c_0
    nig_mean = np.array([0] * model_dim).reshape(-1, 1)
    nig_lambda = np.diag([lambda_0] * model_dim)
    model_prior_params = [nig_mean, nig_lambda, a_0, b_0]

    return data_prior_params, model_prior_params


# just one sample of the parameters
def generate_true_params(data_prior_params, model_prior_params):
    """Draw one set of ground-truth parameters from the two priors.

    Parameters
    ----------
    data_prior_params : positional arguments unpacked into ``NIW_rvs``.
    model_prior_params : positional arguments unpacked into
        ``NIG_rvs_single_variance``.

    Returns
    -------
    dict with keys ``'theta'``, ``'sigma_squared'``, ``'mu_x'`` and ``'Tau'``.
    """
    # The model parameters are requested before the data parameters; keep this
    # order so the sequence of helper calls is unchanged.
    theta, sigma_squared = NIG_rvs_single_variance(*model_prior_params)
    mu_x, Tau = NIW_rvs(*data_prior_params)

    # A bare float mean is promoted to a (1, 1) column vector.
    if isinstance(mu_x, float):
        mu_x = np.array([mu_x]).reshape(-1, 1)

    return {
        'theta': theta,
        'sigma_squared': sigma_squared,
        'mu_x': mu_x,
        'Tau': Tau,
    }


def generate_data(true_params, N, data_dim):

    X = scipy.stats.multivariate_normal.rvs(true_params['mu_x'].flatten(), true_params['Tau'], size=N)

    if N == 1:
        X = np.array([X])

    if data_dim == 1:
        X = X[:, None]

    # append constant bias term
    X = np.hstack((X, np.ones((N, 1))))
    y = scipy.stats.norm.rvs(X.dot(true_params['theta']), np.sqrt(true_params['sigma_squared']))

    return X, y


def privatize_suff_stats_Gauss(S, sensitivity, epsilon, seed=None, lamb=2):
    """Add Gaussian noise to every sufficient statistic in ``S``.

    Parameters
    ----------
    S : dict of statistics; must contain a square matrix under ``'XX'``.
        Every value is perturbed independently, in the dict's iteration order.
    sensitivity : sensitivity of the statistics.
    epsilon : privacy parameter; must be non-zero.
    seed : if not ``None``, the global NumPy RNG is seeded with it first so
        the noise is reproducible. With ``None`` the current global RNG state
        is used as-is.
    lamb : multiplier in the noise variance (presumably the Renyi-DP order --
        TODO confirm against the accompanying derivation).

    Returns
    -------
    (Z, scale): the noisy statistics (same keys as ``S`` plus ``'X'``) and the
    noise standard deviation ``sqrt(lamb * sensitivity**2 / (2 * epsilon))``.
    """
    # Seed only on explicit request: np.random.seed(None) would re-initialise
    # the global RNG from fresh entropy and silently discard any seeding the
    # caller did to make an experiment reproducible.
    if seed is not None:
        np.random.seed(seed)

    scale = np.sqrt((lamb*sensitivity**2)/(2*epsilon))
    Z = {key: np.random.normal(loc=val, scale=scale) for key, val in S.items()}

    # Symmetrize by averaging with the transpose instead of copying one
    # triangle, to fully use the privacy budget: each off-diagonal entry is the
    # mean of two independent noisy copies, so it has half the noise variance.
    Z['XX'] = (Z['XX'] + Z['XX'].T) / 2

    # First column of the noisy XX, kept as a column vector.
    # NOTE(review): generate_data appends the bias column last, so if this is
    # meant to be the noisy sum of X it may need the last column -- confirm.
    Z['X'] = Z['XX'][:, 0][:, None]

    return Z, scale


# def privatize_suff_stats(S, sensitivity_x, sensitivity_y, epsilon, method='Lap'):

#     data_dim = S['XX'].shape[0]

#     XX_comps = data_dim * (data_dim - 1) / 2  # upper triangular, not counting last column which is X (should it be dim -1???)
#     X_comps = data_dim  # last column
#     Xy_comps = data_dim
#     yy_comps = 1

#     if method == 'Lap':

#         sensitivity = XX_comps * max(sensitivity_x[:-1]) ** 2 \
#                       + X_comps * max(sensitivity_x[:-1]) \
#                       + Xy_comps * max(sensitivity_x[:-1]) * sensitivity_y \
#                       + yy_comps * sensitivity_y ** 2

#         scale = sensitivity / epsilon
#         Z = {key: np.random.laplace(loc=val, scale=scale) for key, val in S.items()}

#     elif method=='Gaussian':

#         # L2 sensitivity
#         sensitivity = np.sqrt(XX_comps * max(sensitivity_x[:-1]) ** 4 \
#                       + X_comps * max(sensitivity_x[:-1]) ** 2 \
#                       + Xy_comps * (max(sensitivity_x[:-1]) * sensitivity_y) ** 2 \
#                       + yy_comps * sensitivity_y ** 4)

#         lamb = 2
#         scale = lamb*sensitivity/(2*epsilon)
#         Z = {key: np.random.normal(loc=val, scale=scale) for key, val in S.items()}

#     else:
#         raise ValueError('Invalid DP method.')


#     # symmetrize Z_XX since we only want to add noise to upper triangle
#     Z['XX'] = symmetrize(Z['XX'])

#     Z['X'] = Z['XX'][:, 0][:, None] # what is it used for?

#     return Z, scale
