import os
import sys
sys.path.append(os.getcwd() + '/linreg/')
sys.path.append(os.getcwd() + '/logreg/')

from generate_data import get_raw_data_syn, compute_sensitivity_Gauss, generate_private_data_Gauss, privatize_suff_stats_Gauss

from log_generate_data import generate_approx_ss_noisy
from PASS import PASS

import dill
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error, r2_score, log_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, StandardScaler
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def _preprocess_calih_for_nn(test_size=0.4, random_state=None):
    """Load California housing and robust-scale both features and target.

    The rows are split into a portion reserved for training the neural-network
    feature transformer (``X_nn``/``y_nn``) and the remainder
    (``X_rest``/``y_rest``). Two separate ``RobustScaler(unit_variance=True)``
    instances (one for X, one for y) are fitted on the NN portion only and then
    applied to both portions.

    Args:
        test_size: Fraction of rows assigned to ``X_rest``/``y_rest``; forwarded
            to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``. ``None`` (the
            default) keeps the split unseeded, as before.

    Returns:
        Tuple ``(X_nn, X_rest, y_nn, y_rest)`` of scaled arrays; the targets are
        returned as column vectors of shape ``(n, 1)``.
    """
    X, y = fetch_california_housing(return_X_y=True)

    X_nn, X_rest, y_nn, y_rest = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    # sklearn scalers require 2-D input, so the 1-D targets become columns.
    y_nn = y_nn.reshape(-1, 1)
    y_rest = y_rest.reshape(-1, 1)

    xscaler = RobustScaler(unit_variance=True).fit(X_nn)
    yscaler = RobustScaler(unit_variance=True).fit(y_nn)
    return (xscaler.transform(X_nn), xscaler.transform(X_rest),
            yscaler.transform(y_nn), yscaler.transform(y_rest))

def _train_nn_transformer(intermediate_dims, X, y, X_test, y_test, epochs=200):
    """Train a small MLP regressor and return its last-hidden-layer feature extractor.

    The network is a stack of ReLU ``Dense`` layers with the given widths,
    followed by a single linear output unit. It is trained with Huber loss and
    Adam, and early stopping on ``val_loss`` restores the best weights.

    Args:
        intermediate_dims: Widths of the hidden ReLU layers, in order. Must be
            non-empty; its last entry is the dimensionality of the features the
            returned model produces.
        X, y: Training inputs and targets.
        X_test, y_test: Data used both as Keras validation data (early
            stopping) and for the printed test R-squared.
        epochs: Maximum number of training epochs.

    Returns:
        A ``keras.Model`` sharing the trained layers and mapping inputs to the
        output of the penultimate (last hidden) layer.

    Raises:
        ValueError: If ``intermediate_dims`` is empty.
    """
    if len(intermediate_dims) == 0:
        raise ValueError("intermediate_dims must contain at least one layer width")

    nn_layers = [layers.Dense(intermediate_dims[0], activation=tf.nn.relu,
                              input_shape=[X.shape[1]])]
    # Deeper hidden layers carry an L2 penalty on their activations.
    for size in intermediate_dims[1:]:
        nn_layers.append(layers.Dense(size, activation=tf.nn.relu,
                                      activity_regularizer=tf.keras.regularizers.L2(0.005)))
    # Linear regression head with L2 penalties on its weights and bias.
    nn_layers.append(layers.Dense(1, kernel_regularizer=tf.keras.regularizers.l2(0.005),
                                  bias_regularizer=tf.keras.regularizers.l2(0.005)))

    model = keras.models.Sequential(nn_layers)
    model.compile(loss='huber', optimizer=tf.keras.optimizers.Adam())

    early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=25,
                                                         restore_best_weights=True)
    model.fit(X, y, epochs=epochs, batch_size=256, shuffle=True, verbose=0,
              validation_data=(X_test, y_test), callbacks=[early_stopping_cb])
    model.evaluate(X_test, y_test)

    train_pred = model.predict(X)
    test_pred = model.predict(X_test)
    print("Transfer NN: Train R-squared", r2_score(y, train_pred))
    print("Transfer NN: Test R-squared", r2_score(y_test, test_pred))

    # Cut the network before the output head; the returned model reuses the
    # trained hidden layers.
    transformer = keras.models.Model(model.input, model.get_layer(index=-2).output)

    return transformer

def generate_cali():
    """Build the private linear-regression settings for California housing and pickle them.

    Pipeline:
      1. Robust-scale the data (``_preprocess_calih_for_nn``) and train an MLP
         on the NN portion. Its last hidden layer, of ``data_dim`` units, is
         used as a feature map.
      2. Standardise the learned features with a scaler fitted on the NN
         portion, then append a bias column of ones. Drop training rows with
         any feature outside ``[-4, 4]``.
      3. Derive the data-prior and model-prior hyper-parameters from the NN
         portion.
      4. Split the remaining training rows among three parties (20% / 30% /
         50%). Compute the sensitivity and private statistics via
         ``compute_sensitivity_Gauss`` / ``generate_private_data_Gauss``.
      5. Write ``calih{data_dim}.pkl``: the settings tuple, then
         ``(p_X_nn, p_X_train, p_X_test, y_nn, y_train, y_test)``.
    """
    nP = 3
    data_dim = 6

    # privacy setting for each party
    epsilon = np.array([0.2, 0.1, 0.2])
    DP_method = 'Gaussian'

    X_nn, X_rest, y_nn, y_rest = _preprocess_calih_for_nn()
    X_train, X_test, y_train, y_test = train_test_split(X_rest, y_rest, test_size=0.2)
    # NOTE(review): X_test is also the early-stopping validation set of the
    # transformer, so the final test split is not fully independent of model
    # selection. Consider carving a validation split out of X_nn instead.
    transformer = _train_nn_transformer([data_dim * 8, data_dim * 4, data_dim],
                                        X_nn, y_nn, X_test, y_test)

    xscaler = StandardScaler()
    xscaler.fit(transformer.predict(X_nn))

    def process(M):
        """Map raw rows to standardised NN features plus a trailing bias column of ones."""
        return np.hstack((xscaler.transform(transformer.predict(M)),
                          np.ones((M.shape[0], 1))))

    p_X_nn, p_X_train, p_X_test = process(X_nn), process(X_train), process(X_test)

    # Keep only training rows whose standardised features all lie in [-4, 4]
    # (the bias column is always 1). Presumably this bounds the per-row norm
    # the DP sensitivity depends on -- confirm against compute_sensitivity_Gauss.
    mask = (np.abs(p_X_train) <= 4).all(axis=1)
    p_X_train = p_X_train[mask]
    y_train = y_train[mask]

    # Data prior parameters, computed from the NN portion with the bias column excluded.
    scatter = np.cov(p_X_nn.T, bias=True)[:-1, :-1]
    lam0 = np.diag(np.diag(scatter))  # prior for covariance (diagonal of empirical covariance)
    nu0 = data_dim + 2  # belief strength for covariance
    mu0 = np.mean(p_X_nn[:, :-1], axis=0)  # prior for mean
    kappa0 = 0.1  # belief strength for mean

    data_prior_params = [mu0[:, None],
                         kappa0,
                         lam0,
                         nu0,
                         ]

    # Ordinary least squares on the NN portion. lstsq avoids explicitly
    # inverting the Gram matrix, which is numerically fragile when it is
    # ill-conditioned.
    mle_w = np.linalg.lstsq(p_X_nn, y_nn, rcond=None)[0]
    print("MLE weights, should be near  0", mle_w.ravel())
    mle_s2 = ((y_nn - p_X_nn @ mle_w) ** 2).sum() / (len(y_nn) - 1)

    # Conjugate Bayesian-linear-regression prior worth `blr_n` pseudo-observations:
    # https://www.colorado.edu/amath/sites/default/files/attached-files/bayes_linreg.pdf
    blr_n = 10  # number of fake observations
    blr_nu0 = blr_n / 2
    blr_scale0 = blr_nu0 * mle_s2
    blr_lambd0 = (p_X_nn.T @ p_X_nn) * blr_n / len(y_nn)

    model_prior_params = [np.array([0] * (data_dim + 1))[:, None],
                          blr_lambd0,
                          blr_nu0,
                          blr_scale0,
                          ]

    # Party sizes: X_3 gets 50% of the training rows, X_1 20%, X_2 30%.
    X_3, X_12, y_3, y_12 = train_test_split(p_X_train, y_train, test_size=0.5)
    X_1, X_2, y_1, y_2 = train_test_split(X_12, y_12, test_size=0.6)
    print("Dataset sizes of party 0, 1, 2, combined:", len(y_1), len(y_2), len(y_3), len(p_X_train))

    # One record per party; the test set is built the same way and split off.
    dataset = [{'X': Xp, 'y': yp, 'N': len(yp)}
               for Xp, yp in zip([X_1, X_2, X_3, p_X_test], [y_1, y_2, y_3, y_test])]
    test_dataset = dataset.pop()
    N = np.array([party['N'] for party in dataset])

    sensitivity = compute_sensitivity_Gauss(dataset)
    S, Z, sigma_DP = generate_private_data_Gauss(dataset, sensitivity, epsilon)

    nSample = 30000

    settings = (nP, N, epsilon, DP_method, data_dim, dataset, test_dataset, sensitivity,
                S, Z, sigma_DP, data_prior_params, model_prior_params, nSample)

    with open("calih{}.pkl".format(data_dim), "wb") as f:
        dill.dump(settings, f)
        dill.dump((p_X_nn, p_X_train, p_X_test, y_nn, y_train, y_test), f)

def generate_syn():
    """Create a synthetic three-party linear-regression setting and pickle it to ``syn.pkl``.

    The parties hold 400, 200 and 100 rows of 2-dimensional covariates. A
    fourth draw from the same generator, of size equal to the parties' total,
    is split off as the test set. The file contains the settings tuple
    followed by the generator's true parameters.
    """
    n_parties = 3
    party_sizes = np.array([400, 200, 100])
    eps_per_party = np.array([0.2, 0.1, 0.2])
    mechanism = 'Gaussian'
    dim = 2

    # Draw one extra "party" whose rows become the test set.
    all_sizes = np.concatenate((party_sizes, [np.sum(party_sizes)]))
    parties, true_params = get_raw_data_syn(n_parties + 1, dim, all_sizes)
    held_out = parties.pop()

    # L2 sensitivity of the linear-regression sufficient statistics only.
    sens = compute_sensitivity_Gauss(parties)
    S, Z, sigma_DP = generate_private_data_Gauss(parties, sens, eps_per_party)

    # NIW prior on the covariates, p(mu, Sigma) with x ~ N(mu, Sigma):
    # [mu_0, lambda_0, psi_0, nu_0].
    covariate_prior = [
        np.zeros(dim, dtype=int)[:, None],
        1,
        np.diag(np.ones(dim, dtype=int)),
        50,
    ]

    # NIG prior on (theta, sigma^2): [mu_0, lambda_0 * I, a_0, b_0].
    # The "+ 1" dimension is the bias term of y = theta * x + b.
    a_0, b_0 = 5, .1  # sigma^2 ~ InvGamma(a_0, b_0)
    c_0 = 1
    lambda_0 = b_0 / (a_0 - 1) / c_0
    regression_prior = [
        np.zeros(dim + 1, dtype=int)[:, None],
        lambda_0 * np.eye(dim + 1),
        a_0,
        b_0,
    ]

    draws = 30000
    bundle = (n_parties, party_sizes, eps_per_party, mechanism, dim, parties, held_out,
              sens, S, Z, sigma_DP, covariate_prior, regression_prior, draws)

    with open("syn.pkl", "wb") as out:
        dill.dump(bundle, out)
        dill.dump(true_params, out)


def _load_diabetes(seed=27, d=4, clip=None, path="datasets/diabetes.csv"):
    """Load the diabetes CSV, reduce it to ``d`` standardised PCA components and split it.

    The ``"Outcome"`` column is the label; every other column is a feature.
    After an 80/20 train/test split, a RobustScaler -> PCA(d) -> StandardScaler
    chain is fitted on the training features only and applied to both splits,
    so train and test end up on the same scale.

    Args:
        seed: ``random_state`` for both ``train_test_split`` and ``PCA``.
        d: Number of PCA components kept.
        clip: If not ``None``, clip every feature of both splits to
            ``[-clip, clip]`` after standardisation.
        path: CSV file to read (defaults to the previously hard-coded path).

    Returns:
        ``(X_train, X_test, y_train, y_test)``.
    """
    df = pd.read_csv(path)
    y = df.pop("Outcome").to_numpy()
    X = df.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

    print("Training data dimensions:", X_train.shape)

    robust = RobustScaler().fit(X_train)
    X_train = robust.transform(X_train)
    pca = PCA(n_components=d, random_state=seed).fit(X_train)
    X_train = pca.transform(X_train)
    post_scaler = StandardScaler().fit(X_train)
    # The standardised result must replace X_train. Previously it went to an
    # unused variable, leaving the training data unstandardised while the test
    # data below was standardised.
    X_train = post_scaler.transform(X_train)

    X_test = post_scaler.transform(pca.transform(robust.transform(X_test)))

    if clip is not None:
        X_train = np.clip(X_train, -clip, clip)
        X_test = np.clip(X_test, -clip, clip)

    return X_train, X_test, y_train, y_test

def generate_diab():
    """Build the private logistic-regression settings for the diabetes data and pickle them.

    Steps:
      - Load the preprocessed data with features clipped to ``[-2.2, 2.2]``.
      - Print non-private baselines: sklearn logistic regression without an
        intercept, and the log loss of always predicting the training
        positive rate.
      - Split the training rows among three parties (about 50% / 30% / 20%).
      - Call ``generate_approx_ss_noisy`` once per party with that party's
        epsilon and the sensitivity ``GS``.

    The output is written to ``diab.pkl``: the settings tuple, followed by
    ``(X_train, X_test, y_train, y_test)``.
    """
    X_train, X_test, y_train, y_test = _load_diabetes(clip=2.2)
    d = X_train.shape[1]
    # Largest row norm of the training features; the sensitivity below depends on it.
    R = np.linalg.norm(X_train, axis=1).max()
    print("Data norm", R)

    # Non-private baseline.
    fit = LogisticRegression(fit_intercept=False, random_state=1234).fit(X_train, y_train)
    print("classes, coeffs", fit.classes_, fit.coef_)
    y_scores_non_private = fit.predict_proba(X_test)[:, 1]
    print("log loss", log_loss(y_test, y_scores_non_private))
    l2_norm = np.linalg.norm(fit.coef_)
    print("l2 norm of theta", l2_norm)

    # Test-set cross-entropy of always predicting the training positive rate.
    frac_positive_in_ytrain = (y_train == 1).sum() / len(y_train)
    frac_positive_in_ytest = (y_test == 1).sum() / len(y_test)
    random_guess_loss = -(frac_positive_in_ytest * np.log(frac_positive_in_ytrain) +
                          (1 - frac_positive_in_ytest) * np.log(1 - frac_positive_in_ytrain))
    print("\n\n fraction positive", frac_positive_in_ytrain, frac_positive_in_ytest)
    print("Random guess naive log-loss", random_guess_loss)

    GS = np.sqrt(0.5 + 2.0 * R**2 + 2.0 * R**4)  # sensitivity, as a function of R

    nP = 3  # number of parties

    # privacy setting for each party
    epsilons = np.array([0.2, 0.1, 0.2])
    DP_method = 'Gaussian'

    # Party 1 gets half the training rows; the other half is split 60/40 into parties 2 and 3.
    X_1, X_2, y_1, y_2 = train_test_split(X_train, y_train, test_size=0.5, random_state=28)
    X_2, X_3, y_2, y_3 = train_test_split(X_2, y_2, test_size=0.4, random_state=27)

    pass_object = PASS(d)
    N1, approx_ss1, perturbed_ss1, variances1 = generate_approx_ss_noisy(
        X_1, y_1, epsilons[0], GS, pass_object, seed=245)
    N3, approx_ss3, perturbed_ss3, variances3 = generate_approx_ss_noisy(
        X_3, y_3, epsilons[2], GS, pass_object, seed=178)
    N2, approx_ss2, perturbed_ss2, variances2 = generate_approx_ss_noisy(
        X_2, y_2, epsilons[1], GS, pass_object, seed=396)

    # Stack per-party results in party order 1, 2, 3.
    approx_ss = np.stack((approx_ss1, approx_ss2, approx_ss3))
    perturbed_ss = np.stack((perturbed_ss1, perturbed_ss2, perturbed_ss3))
    variances = np.stack((variances1, variances2, variances3))
    N = np.array([N1, N2, N3])

    multiX = [X_1, X_2, X_3]
    multiY = [y_1, y_2, y_3]

    settings = (nP, N, epsilons, DP_method, d, multiX, multiY, R, GS, l2_norm,
                pass_object, approx_ss, perturbed_ss, variances)

    with open("diab.pkl", "wb") as f:
        dill.dump(settings, f)
        dill.dump((X_train, X_test, y_train, y_test), f)



if __name__ == '__main__':
    # Optionally name the settings files to (re)generate on the command line,
    # e.g. `python <script> syn diab`. With no arguments only the
    # California-housing settings are generated.
    _GENERATORS = {'cali': generate_cali, 'syn': generate_syn, 'diab': generate_diab}
    _targets = sys.argv[1:] or ['cali']
    _unknown = [t for t in _targets if t not in _GENERATORS]
    if _unknown:
        sys.exit("Unknown dataset(s) {}; choose from {}".format(_unknown, sorted(_GENERATORS)))
    for _target in _targets:
        _GENERATORS[_target]()
