import numpy as np
import tensorflow as tf

# Enable memory growth on every visible GPU, so TensorFlow allocates GPU
# memory as needed rather than reserving it all up front.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sklearn
from sklearn.linear_model import LogisticRegression
from tensorflow.keras import layers, models
import tensorflow_privacy
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
from sklearn.metrics import roc_auc_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

# from dp_optimizer import make_gaussian_optimizer_class
from utils import checkpoint_name

import math
import numpy as np
import argparse
import numpy as np

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# from opacus import PrivacyEngine
# from torchvision import datasets, transforms
from tqdm import tqdm
from tmp_utils.GaussianCalibrator import calibrateAnalyticGaussianMechanism
import math

# from tmp_utils.poisson_sampler import poisson_sampler
# from tmp_utils.mu_search import mu0_search, cal_step_decay_rate
# from scipy.stats import norm
# from scipy import optimize
from utils import ObjectView
import utils

# import tf_privacy


# import time

import os

from optimize_thr import optimize_easy

import pandas as pd


def compute_weights(df_train, synth_data, data):
    """Return per-row weights giving the real and synthetic parts equal total mass.

    Every one of the ``len(df_train)`` leading entries gets
    ``len(data) / (2 * len(df_train))``. Every one of the ``len(synth_data)``
    trailing entries gets ``len(data) / (2 * len(synth_data))``. The weights
    therefore sum to ``len(data)``. Only lengths are checked; presumably
    ``data`` is the real rows followed by the synthetic rows.
    """
    n_real = len(df_train)
    n_synth = len(synth_data)
    n_total = len(data)
    assert n_total == n_real + n_synth

    real_part = np.ones(n_real) / n_real
    synth_part = np.ones(n_synth) / n_synth
    weights = np.concatenate([real_part, synth_part], 0) / 2.0 * n_total

    # Sanity check: the two halves together carry len(data) units of mass.
    total_mass = sum(weights)
    assert n_total - 1e-3 < total_mass < n_total + 1e-3
    return weights


def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def accuracy(logits, target, weights=None, logit=True):
    """Return the (optionally weighted) 0/1 accuracy of thresholded scores.

    Args:
        logits: one score per example. Singleton dimensions are squeezed away,
            e.g. shape ``(n, 1)`` -> ``(n,)``. They are raw logits when
            ``logit`` is True, otherwise probabilities.
        target: binary labels (0/1 floats/ints or bools), one per example.
        weights: optional per-example weights. They must sum to
            ``len(logits)`` (within 1e-3) so the result stays in [0, 1].
        logit: if True, ``logits`` are passed through ``sigmoid`` first.

    Returns:
        ``1 - sum(weight_i * [pred_i != target_i]) / n``, where ``pred_i`` is
        ``score_i > 0.5``. Without ``weights`` every weight is 1.
    """
    n = len(logits)
    assert weights is None or abs(np.sum(weights) - n) < 1e-3
    scores = sigmoid(logits) if logit else logits
    # atleast_1d keeps a single-example input (e.g. shape (1, 1)) iterable
    # after squeezing; casting to float allows boolean targets.
    predictions = np.atleast_1d(np.squeeze(scores) > 0.5).astype(float)
    errors = (predictions - np.asarray(target, dtype=float)) ** 2
    if weights is not None:
        errors = np.asarray(weights, dtype=float) * errors
    return 1 - np.sum(errors) / n


def dp_class_adam_mlp(
    df_train, df_test, synth_data, epsilon, delta, model_id, reload, seed, private=True
):
    """Train a DP-Adam Keras MLP separating real rows (label 1) from synthetic rows (label 0).

    The Gaussian noise multiplier is calibrated with TensorFlow Privacy's
    ``compute_dp_sgd_privacy`` accountant. The accounted epsilon does not
    exceed ``epsilon`` (up to 1e-10).

    Args:
        df_train: real training rows (DataFrame; ``.values`` is used).
        df_test: evaluation rows (DataFrame or array). Only used to print the
            mean class-1 output as a diagnostic.
        synth_data: synthetic rows (DataFrame). Truncated to ``len(df_train)``
            rows, and ``df_train`` is truncated to match.
        epsilon: target privacy budget for the accountant.
        delta: if falsy, the accountant uses ``1 / len(df_train)``; otherwise
            it uses ``delta / len(df_train)``.
        model_id, reload, seed, private: unused. Kept so the signature matches
            the other discriminators in this module.

    Returns:
        1-D array with the model's class-1 output for every truncated synthetic
        row. This is a logit, because the last layer has no activation.
    """
    import tensorflow as tf

    NORM_CLIP = 1.1  # Does NOT affect EPSILON, but increases NOISE on gradients
    NOISE_MULT = 1.15
    NOISE_STEP = 0.1  # granularity of the noise-multiplier search
    LEARN_R = 0.01
    EPOCHS = 5
    BATCH_SIZE = 10
    # Number of real (private) rows. The training set built below also holds
    # up to this many synthetic rows.
    BUFFER_SIZE = len(df_train)
    # Needs to be smaller than 1/BUFFER_SIZE
    DP_DELTA = 1 / BUFFER_SIZE if not delta else delta / BUFFER_SIZE

    def _accounted_epsilon(noise_multiplier, epochs):
        """Return the accountant's epsilon for the given noise level and epochs."""
        eps_value, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
            n=BUFFER_SIZE,
            batch_size=BATCH_SIZE,
            # Must equal the multiplier handed to the optimizer. Accounting
            # for more noise than is actually added under-reports epsilon.
            noise_multiplier=noise_multiplier,
            epochs=epochs,
            delta=DP_DELTA,
        )
        return eps_value

    # Phase 1: get within budget, first by cutting epochs, then by adding noise.
    eps = _accounted_epsilon(NOISE_MULT, EPOCHS)
    while eps > epsilon + 1e-10:
        if EPOCHS > 1:
            EPOCHS = 1
        else:
            NOISE_MULT += NOISE_STEP
            if NORM_CLIP > 0.5:
                NORM_CLIP -= 0.05
        eps = _accounted_epsilon(NOISE_MULT, EPOCHS)

    # Phase 2: spend the remaining budget by lowering the noise. Never go to
    # zero or negative noise.
    while eps < epsilon and NOISE_MULT - NOISE_STEP > 0:
        NOISE_MULT -= NOISE_STEP
        eps = _accounted_epsilon(NOISE_MULT, EPOCHS)

    # The last decrement may have overshot the budget; step back to the last
    # noise level that was within it.
    if eps > epsilon:
        NOISE_MULT += NOISE_STEP

    synth_rows = synth_data.values[: len(df_train)]
    real_rows = df_train.values[: len(synth_rows)]

    data = np.concatenate([real_rows, synth_rows], 0)
    labels = np.concatenate([np.ones(len(real_rows)), np.zeros(len(synth_rows))], 0)
    labels = tf.keras.utils.to_categorical(labels, 2)

    # The DP optimizer reshapes each minibatch's loss into `num_microbatches`
    # equal parts, so drop the ragged tail to keep every batch full-sized.
    n_usable = len(data) - len(data) % BATCH_SIZE
    data, labels = data[:n_usable], labels[:n_usable]
    # One example per microbatch, i.e. per-example gradient clipping.
    num_microbatches = BATCH_SIZE

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(2),
        ]
    )

    optimizer = tensorflow_privacy.DPAdamGaussianOptimizer(
        l2_norm_clip=NORM_CLIP,
        noise_multiplier=NOISE_MULT,
        num_microbatches=num_microbatches,
        learning_rate=LEARN_R,
    )

    # Per-example (unreduced) loss so the optimizer can clip per microbatch.
    loss = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE
    )

    model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
    model.fit(
        data,
        labels,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )

    # df_test may be a DataFrame or an already-converted array.
    test_values = getattr(df_test, "values", df_test)
    predictions = model.predict(test_values)
    print("true pred: ", predictions[:, 1].mean())
    predictions = model.predict(synth_rows)
    print("synth pred: ", predictions[:, 1].mean())

    return predictions[:, 1]


def class_mlp(
    df_train, df_test, synth_data, epsilon, delta, model_id, reload, seed, private=True
):
    """Train a non-private sklearn MLP separating real rows (1) from synthetic rows (0).

    Args:
        df_train: real rows (DataFrame; ``.values`` is used).
        df_test, epsilon, delta, model_id, reload, private: unused. Kept so the
            signature matches the other discriminators in this module.
        synth_data: synthetic rows (DataFrame). The first ``len(df_train)``
            rows are used for fitting, but every row is scored.
        seed: passed to ``MLPClassifier`` as ``random_state`` for reproducible
            weight initialisation and shuffling.

    Returns:
        ``(log_weights, None)``. ``log_weights[i]`` is
        ``log p(label 1 | x_i) - log p(label 0 | x_i)`` for each row ``x_i`` of
        ``synth_data``.
    """
    # `import sklearn` alone is not guaranteed to expose the `neural_network`
    # submodule, so import the estimator explicitly.
    from sklearn.neural_network import MLPClassifier

    synth_rows = synth_data.values[: len(df_train)]
    real_rows = df_train.values[: len(synth_rows)]

    data = np.concatenate([real_rows, synth_rows], 0)
    labels = np.concatenate([np.ones(len(real_rows)), np.zeros(len(synth_rows))], 0)

    # Author's note: solver="sgd" (momentum=0) performed much worse, so the
    # default solver is kept.
    mlp = MLPClassifier(
        hidden_layer_sizes=(100,),
        activation="relu",
        random_state=seed,
    )
    mlp.fit(data, labels)

    # Score with a plain array: the model was fitted without feature names.
    log_proba = mlp.predict_log_proba(synth_data.values)
    print(f"Accuracy of nonDP MLP discr {mlp.score(data, labels)}")

    # classes_ are sorted as [0., 1.], so column 1 is the "real" class.
    log_weights = log_proba[:, 1] - log_proba[:, 0]

    return log_weights, None


# def dp_class_mlp(
#     df_train, df_test, synth_data, epsilon, delta, model_id, reload, seed, private=True
# ):

#     synth_data_ = synth_data.values[: len(df_train)]
#     df_train = df_train.values[: len(synth_data_)]

#     data = np.concatenate([df_train, synth_data_], 0)
#     labels = np.concatenate([np.ones(len(df_train)), np.zeros(len(synth_data_))], 0)
#     labels = tf.keras.utils.to_categorical(labels, 2)
#     # weights = compute_weights(df_train, synth_data_, data)

#     BUFFER_SIZE = len(df_train)  # Total size of training data
#     NORM_CLIP = 1.1  # Does NOT affect EPSILON, but increases NOISE on gradients

#     NOISE_MULT = 1.15

#     DP_DELTA = (
#         1 / BUFFER_SIZE if not delta else delta / BUFFER_SIZE
#     )  # Needs to be smaller than 1/BUFFER_SIZE
#     EPOCHS = 20

#     BATCH_SIZE = min(600, BUFFER_SIZE)

#     eps, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
#         n=BUFFER_SIZE,
#         batch_size=BATCH_SIZE,
#         noise_multiplier=NOISE_MULT,
#         epochs=EPOCHS,
#         delta=DP_DELTA,
#     )

#     while eps > epsilon + 1e-10:
#         if EPOCHS > 1:
#             EPOCHS = 1
#         elif BATCH_SIZE > 60:
#             BATCH_SIZE -= 10
#         else:
#             NOISE_MULT += 0.1
#         eps, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
#             n=BUFFER_SIZE,
#             batch_size=BATCH_SIZE,
#             noise_multiplier=NOISE_MULT,
#             epochs=EPOCHS,
#             delta=DP_DELTA,
#         )

#     while eps < epsilon:
#         if BATCH_SIZE < BUFFER_SIZE:
#             BATCH_SIZE += 10
#         else:
#             NOISE_MULT -= 0.1

#         eps, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
#             n=BUFFER_SIZE,
#             batch_size=BATCH_SIZE,
#             noise_multiplier=NOISE_MULT,
#             epochs=EPOCHS,
#             delta=DP_DELTA,
#         )

#     NR_MICROBATCHES = (
#         BATCH_SIZE  # Each batch of data is split in smaller units called microbatches.
#     )

#     # Learning Rate for DISCRIMINATOR
#     LR_DISC = tf.compat.v1.train.polynomial_decay(
#         learning_rate=0.150,
#         global_step=tf.compat.v1.train.get_or_create_global_step(),
#         decay_steps=10000,
#         end_learning_rate=0.052,
#         power=1,
#     )

#     checkpoint_dir = f"output/{model_id}_dpmlp/training_checkpoints"
#     if not os.path.exists(checkpoint_dir):
#         os.makedirs(checkpoint_dir)

#     def make_discriminator_model_FCC():

#         # INPUT: Image
#         in_image = layers.Input(shape=(df_train.shape[1]))

#         ge1 = layers.Dense(df_train.shape[1], use_bias=True)(in_image)
#         ge1 = layers.ReLU()(ge1)
#         ge1 = layers.Dense(df_train.shape[1], use_bias=True)(ge1)
#         ge1 = layers.ReLU()(ge1)

#         out_layer = layers.Dense(2, use_bias=True)(ge1)

#         model = models.Model([in_image], out_layer)

#         return model

#     GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
#     DPGradientDescentGaussianOptimizer_NEW = make_gaussian_optimizer_class(
#         GradientDescentOptimizer
#     )

#     discriminator = make_discriminator_model_FCC()
#     cross_entropy_DISC = tf.keras.losses.BinaryCrossentropy(
#         from_logits=True, reduction=tf.losses.Reduction.NONE
#     )

#     @tf.function
#     def train_step_DISC(images, labels):
#         with tf.GradientTape(persistent=True) as disc_tape_real:
#             # This dummy call is needed to obtain the var list.
#             dummy = discriminator([images], training=True)
#             var_list = discriminator.trainable_variables

#             # In Eager mode, the optimizer takes a function that returns the loss.
#             def loss_fn_real():
#                 real_output = discriminator([images], training=True)
#                 disc_real_loss = cross_entropy_DISC(labels, real_output)
#                 return disc_real_loss

#             grads_and_vars_real = discriminator_optimizer.compute_gradients(
#                 loss_fn_real,
#                 var_list,
#                 gradient_tape=disc_tape_real,
#                 curr_noise_mult=NOISE_MULT,
#                 curr_norm_clip=NORM_CLIP,
#             )

#             # In Eager mode, the optimizer takes a function that returns the loss.
#             disc_loss_r = loss_fn_real()

#             s_grads_and_vars = grads_and_vars_real
#             sanitized_grads_and_vars = list(zip(s_grads_and_vars, var_list))

#             discriminator_optimizer.apply_gradients(sanitized_grads_and_vars)

#         return disc_loss_r

#     def train(dataset, title, verbose):
#         for epoch in range(EPOCHS):
#             # start = time.time()

#             i_gen = 0
#             for image_batch, label_batch in dataset:
#                 if verbose:
#                     print("Iteration: " + str(i_gen + 1))

#                 d_loss_r = train_step_DISC(image_batch, label_batch)
#                 if verbose:
#                     print("Loss DISC Real: " + str(tf.reduce_mean(d_loss_r)))

#                 i_gen = i_gen + 1

#             # print("Time for epoch {} is {} sec".format(epoch + 1, time.time() - start))

#             # Save the model
#             checkpoint.save(
#                 file_prefix=checkpoint_name(
#                     title + "__epoch=" + str(epoch) + "__", checkpoint_dir
#                 )
#             )

#     discriminator_optimizer = DPGradientDescentGaussianOptimizer_NEW(
#         learning_rate=LR_DISC,
#         l2_norm_clip=NORM_CLIP,
#         noise_multiplier=NOISE_MULT,
#         num_microbatches=NR_MICROBATCHES,
#     )

#     # Create/reinitiate models
#     discriminator = make_discriminator_model_FCC()

#     # Create checkpoint structure
#     checkpoint = tf.train.Checkpoint(
#         discriminator_optimizer=discriminator_optimizer, discriminator=discriminator
#     )

#     # Create/reinitiate models
#     discriminator = make_discriminator_model_FCC()

#     tf.random.set_seed(seed)

#     # GIVES CURRENT TRIAL A NAME - Suggestion: from parameters used
#     training_title = f"eps{eps:.2f}"

#     model_pth = (
#         checkpoint_dir
#         + "/ckpt__"
#         + str(training_title)
#         + "__epoch="
#         + str(EPOCHS - 1)
#         + "__-"
#         + str(EPOCHS)
#     )

#     if not os.path.exists(model_pth) or reload:
#         dataset = (
#             tf.data.Dataset.from_tensor_slices((data, labels))
#             .shuffle(BUFFER_SIZE)
#             .batch(BATCH_SIZE, drop_remainder=True)
#         )

#         train(dataset, training_title, False)

#     else:
#         checkpoint.restore(model_pth)

#     logits = discriminator([data], training=False)[:, -1]
#     print(f"Accuracy of DP MLP discr {accuracy(logits, labels[:, 1]):.2f}")
#     print(f"Auroc of DP MLP discr {roc_auc_score(labels[:, 1], sigmoid(logits)):.2f}")

#     log_weights = discriminator([df_train], training=False)[:, -1]
#     logits_test_df = discriminator([df_test.values], training=False)[:, -1]

#     return log_weights


def dp_class_logreg(
    df_train,
    df_test,
    synth_data_,
    epsilon,
    reg,
    n_d,
    unbiased=False,
    args=None,
    wandb=None,
):
    """Fit a logistic-regression discriminator and privatise its log-weights.

    The discriminator separates real rows (label 1) from synthetic rows
    (label 0). Its logit on each synthetic row is used as a log importance
    weight. A private version is obtained by adding Laplace noise to the
    intercept and the coefficients.

    Args:
        df_train: real rows (array or DataFrame).
        df_test: unused.
        synth_data_: synthetic rows (DataFrame). The first ``len(df_train)``
            rows are used for fitting; every row is scored.
        epsilon: privacy budget used in the Laplace scale.
        reg: L2 regularisation strength (``C = 1 / reg``), also used in the
            Laplace scale.
        n_d: sample size used in the Laplace scale.
        unbiased: if True, add the ``comp_mean`` correction to the noised
            log-weights where it exceeds a fixed threshold, and log the
            squared-error change.
        args: forwarded to ``utils.log``.
        wandb: object with a ``log`` method. Only used when ``unbiased``;
            skipped when None.

    Returns:
        ``(log_isweights_synth, log_isweights_beta)``: the non-private and
        the Laplace-noised log-weights, one entry per row of ``synth_data_``.

    Raises:
        ValueError: if the correction term or either weight vector has NaNs.
    """
    d = synth_data_.shape[1]
    synth_values = synth_data_.values
    synth_fit = synth_values[: len(df_train)]

    data = np.concatenate([df_train, synth_fit])
    labels = np.concatenate([np.ones(len(df_train)), np.zeros(len(synth_fit))])

    discr = LogisticRegression(solver="lbfgs", penalty="l2", C=1 / reg, max_iter=4000)

    discr.fit(data, labels)
    print(f"Discr Logreg Accuracy is {discr.score(data, labels):.2f}")
    print(
        f"Discr Logreg Auroc is {roc_auc_score(labels, discr.predict_proba(data)[:, 1]):.2f}"
    )

    # Non-private log-weights: the discriminator's logit on every synthetic row.
    log_isweights_synth = (synth_data_ @ discr.coef_[0] + discr.intercept_).values

    # privatise: entry 0 perturbs the intercept, entries 1..d the coefficients.
    lapl_scale = 4 * np.sqrt(d) / (n_d * reg * epsilon)
    lapl_noise = np.random.laplace(loc=0, scale=lapl_scale, size=(d + 1))
    # Use *all* synthetic rows so the result lines up element-wise with
    # log_isweights_synth. The fitting subset is shorter whenever synth_data_
    # has more rows than df_train.
    log_isweights_beta = (
        log_isweights_synth + lapl_noise[0] + synth_values @ lapl_noise[1:]
    )

    if unbiased:

        def weights_error(x, true=log_isweights_synth):
            """Mean squared error between exp(true) and exp(x)."""
            return ((np.exp(true) - np.exp(x)) ** 2).mean()

        old_error = weights_error(log_isweights_beta)
        utils.log(
            args,
            "lapl_noised_beta",
            "beta_mse",
            old_error,
        )

        # For Laplace noise with scale b, E[exp(b_noise * x)] = 1 / (1 - (b x)^2)
        # when |b x| < 1. Adding sum_j log(1 - (b x_j)^2) presumably compensates
        # the upward bias of exp(noised log-weight). NaNs appear when |b x_j| > 1.
        lapl_scale_vec = np.concatenate(
            (lapl_scale * synth_data_, np.ones((len(synth_data_), 1)) * lapl_scale),
            axis=1,
        )
        comp_mean = np.log(1 - lapl_scale_vec**2).sum(1)
        if np.isnan(comp_mean).any():
            raise ValueError("NANs in beta bias")

        # Hard-coded threshold; an `optimize_easy`-based search was tried before.
        thr = -5
        apply_mask = comp_mean > thr
        log_isweights_beta[apply_mask] += comp_mean[apply_mask]

        new_error = weights_error(log_isweights_beta)
        change = new_error - old_error
        if change > 0:
            print("thr Error increased by " + str(change))
        else:
            print("thr Error decreased by " + str(change))
        utils.log(
            args,
            f"thr{thr}_debiased",
            "beta_mse",
            new_error,
        )
        if wandb is not None:
            wandb.log({"beta_mse": old_error})
            wandb.log({"debiased_beta_mse": new_error})
        utils.log(
            args,
            "debiased",
            "beta_mse",
            new_error,
        )

    if np.isnan(log_isweights_synth).any():
        raise ValueError("NAN weight Error")
    if np.isnan(log_isweights_beta).any():
        raise ValueError("NAN weight Error")

    return (
        log_isweights_synth,
        log_isweights_beta,
    )


def opacus_class_mlp(
    df_train, df_test, synth_data, epsilon, delta, model_id, reload, seed, private=True
):
    """Train a PyTorch MLP separating real rows (label 1) from synthetic rows (label 0).

    With ``private=True`` the model is trained through a ``PrivacyEngine``.
    Batches are drawn with ``poisson_sampler``, and the clipping norm and noise
    level are re-set at the start of every epoch from step-based decaying
    schedules. With ``private=False``, plain SGD over a shuffled
    ``DataLoader`` is used.

    NOTE(review): as the module stands, this function raises ``NameError``.
    It uses ``torch``, ``nn``, ``F``, ``optim``, ``PrivacyEngine``,
    ``poisson_sampler``, ``mu0_search`` and ``cal_step_decay_rate``, but their
    imports at the top of the module are commented out.
    NOTE(review): ``PrivacyEngine(model, sample_rate=..., ...)``, ``attach``,
    ``set_clip`` and ``set_unit_sigma`` look tied to a specific (possibly
    customised) Opacus version. Confirm before reviving this path.

    Args:
        df_train: real rows (DataFrame; ``.values`` is used).
        df_test, model_id, reload, seed: not used in the body.
        synth_data: synthetic rows (DataFrame). The first ``len(df_train)``
            rows are used for training; every row is scored at the end.
        epsilon: target epsilon. Forced to None when ``private`` is False.
        delta: not read. Replaced by ``1 / num_data`` when ``epsilon`` is set.
        private: selects the DP training path.

    Returns:
        NumPy array with the model's class-1 output (a logit) for every row of
        ``synth_data``.
    """
    def train(
        args,
        step,
        model,
        device,
        train_loader,
        optimizer,
        train_dataset=False,
        dp=False,
        sens_decay=False,
        mu_allocation=False,
        privacy_engine=None,
    ):
        """Run one training pass and return the updated global step count.

        Non-DP: iterate ``train_loader`` once. DP: first update the engine's
        clip norm / noise from ``step``, then draw ``int(1 / sampling_rate)``
        batches with ``poisson_sampler`` (about one epoch in expectation).
        """
        model.train()
        criterion = nn.CrossEntropyLoss()
        losses = []
        correct = 0
        total = 0
        if dp == False:
            for _batch_idx, (data, target) in enumerate(tqdm(train_loader)):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
                step += 1
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.shape[0]
        else:
            # Schedules are applied once per call, using the step count at the
            # start of this pass.
            if sens_decay:
                clip = args.max_per_sample_grad_norm * (args.decay_rate_sens) ** step
                privacy_engine.set_clip(clip)
            if mu_allocation:
                unit_sigma = 1 / (args.mu_0 / (args.decay_rate_mu ** (step)))
                privacy_engine.set_unit_sigma(unit_sigma)
            for i in tqdm(range(int(1 / args.sampling_rate))):
                data, target = poisson_sampler(train_dataset, args.sampling_rate)
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
                step += 1
                pred = output.argmax(dim=1, keepdim=True)

                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.shape[0]
        # NOTE(review): `acc` is unused; the print below recomputes it.
        acc = 100.0 * correct / total
        # Reports running *training* statistics despite the "Test set" label.
        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
                np.mean(losses),
                correct,
                total,
                100.0 * correct / total,
            )
        )
        return step

    class SampleConvNet(nn.Module):
        # Despite the name, a one-hidden-layer MLP: num_dim -> 100 -> 2 logits.
        def __init__(self, num_dim):
            super().__init__()
            self.fc1 = nn.Linear(num_dim, 100)
            self.fc4 = nn.Linear(100, 2)

        def forward(self, x):
            # x: presumably a float batch of shape [B, num_dim]
            x = F.gelu(self.fc1(x))  # -> [B, 100]
            # x = F.relu(self.fc2(x))  # -> [B, 32]
            # x = F.relu(self.fc3(x))  # -> [B, 32]
            x = self.fc4(x)  # -> [B, 2]
            return x

        def name(self):
            return "SampleConvNet"

    # Balance the classes by truncating both parts to the shorter length.
    synth_data_ = synth_data.values[: len(df_train)]
    df_train = df_train.values[: len(synth_data_)]

    data = np.concatenate([df_train, synth_data_], 0)
    labels = np.concatenate(
        [np.ones(len(df_train), dtype=int), np.zeros(len(synth_data_), dtype=int)], 0
    )
    if private:
        dp = True
        decay_rate_sens = 0.3
        decay_rate_mu = 0.8
        sens_decay = True
        mu_allocation = True
        lr = 0.15  #!
        epochs = 20  #!
    else:
        dp = False
        decay_rate_sens = None
        decay_rate_mu = None
        sens_decay = False  #!
        mu_allocation = False  #!
        epsilon = None
        lr = 0.15
        epochs = 20

    num_data = len(data)
    batch_size = 256  # this is the expected batch size since we use poisson
    sampling_rate = batch_size / num_data
    iteration = int(epochs / sampling_rate)
    # NOTE(review): mu, mu_t, sigma, clipping_value and delta are only bound in
    # this branch. private=True with a falsy epsilon fails with NameError at
    # `parser_dict.update` below.
    if epsilon:
        clipping_value = 1.5
        delta = 1.0 / num_data
        mu = 1 / calibrateAnalyticGaussianMechanism(
            epsilon=epsilon, delta=delta, GS=1, tol=1.0e-12
        )
        mu_t = math.sqrt(math.log(mu**2 / (sampling_rate**2 * iteration) + 1))
        sigma = 1 / mu_t
        # Convert the decay rates into per-step rates over `iteration` steps.
        if mu_allocation:
            decay_rate_mu = cal_step_decay_rate(decay_rate_mu, iteration)
        if sens_decay:
            decay_rate_sens = cal_step_decay_rate(decay_rate_sens, iteration)
    # NOTE(review): "batch_size" appears twice; the second entry is redundant.
    parser_dict = {
        "batch_size": batch_size,
        "sampling_rate": sampling_rate,
        "decay_rate_sens": decay_rate_sens,
        "decay_rate_mu": decay_rate_mu,
        "batch_size": batch_size,
        "sens_decay": sens_decay,
        "eps": epsilon,
        "lr": lr,
        "epochs": epochs,
        "n_runs": 1,
        "device": "cpu",
        "save_model": False,
        "dp": dp,
    }
    if private:
        parser_dict.update(
            {
                "mu_t": mu_t,
                "sigma": sigma,
                "max_per_sample_grad_norm": clipping_value,
                "delta": delta,
            }
        )

    # decay_rate_mu is None when private is False, so this only runs on the
    # private path (assuming cal_step_decay_rate returns a non-zero rate).
    if decay_rate_mu:
        parser_dict["mu_0"] = mu0_search(
            mu, iteration, decay_rate_mu, sampling_rate, mu_t=mu_t
        )
    args = ObjectView(parser_dict)

    device = torch.device(args.device)
    kwargs = {"num_workers": 1, "pin_memory": True}

    train_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(data).float(), torch.from_numpy(labels)
    )
    # Only used on the non-DP path; the DP path samples from train_dataset.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    step = 0
    model = SampleConvNet(data.shape[1]).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    privacy_engine = None
    if dp:
        privacy_engine = PrivacyEngine(
            model,
            sample_rate=sampling_rate,
            max_grad_norm=args.max_per_sample_grad_norm,
            noise_multiplier=sigma,
        )
        privacy_engine.attach(optimizer)
    for epoch in range(1, args.epochs + 1):
        step = train(
            args,
            step,
            model,
            device,
            train_loader,
            optimizer,
            train_dataset=train_dataset,
            dp=dp,
            sens_decay=sens_decay,
            mu_allocation=mu_allocation,
            privacy_engine=privacy_engine,
        )

    with torch.no_grad():
        model.eval()
        # Class-1 logit for every synthetic row (not only the training subset).
        synth_data_labels = model(
            torch.from_numpy(synth_data.values).float().to(device)
        ).squeeze()[:, 1]
        true_data_labels = model(
            torch.from_numpy(df_train).float().to(device)
        ).squeeze()[:, 1]
        # Diagnostics: sigmoid of the class-1 logit alone, not a softmax over
        # both outputs.
        print("Synth data preds: ", torch.sigmoid(synth_data_labels).mean())
        print("True data preds: ", torch.sigmoid(true_data_labels).mean())

        return synth_data_labels.numpy()


from math import exp, sqrt
from scipy.special import erf


def calibrateAnalyticGaussianMechanism(epsilon, delta, GS, tol=1.0e-12):
    """Calibrate Gaussian noise with the analytic Gaussian mechanism of [Balle and Wang, ICML'18].

    The method maps the target ``delta`` onto an auxiliary variable ``s``. It
    brackets ``s`` by doubling, refines it by bisection until the privacy
    curve is within ``tol`` of ``delta``, and converts it to a noise scale.

    Arguments:
    epsilon : target epsilon (epsilon > 0)
    delta : target delta (0 < delta < 1)
    GS : upper bound on L2 global sensitivity (GS >= 0)
    tol : error tolerance on delta for the binary search (tol > 0)

    Output:
    sigma : standard deviation of Gaussian noise needed to achieve
        (epsilon, delta)-DP under global sensitivity GS

    This module-level definition replaces the same name imported from
    ``tmp_utils.GaussianCalibrator`` earlier in the module.
    """
    root_two = sqrt(2.0)

    def std_normal_cdf(t):
        return 0.5 * (1.0 + erf(float(t) / root_two))

    def curve_a(s):
        # "caseA" of the reference implementation (delta above the threshold).
        return std_normal_cdf(sqrt(epsilon * s)) - exp(epsilon) * std_normal_cdf(
            -sqrt(epsilon * (s + 2.0))
        )

    def curve_b(s):
        # "caseB" of the reference implementation (delta below the threshold).
        return std_normal_cdf(-sqrt(epsilon * s)) - exp(epsilon) * std_normal_cdf(
            -sqrt(epsilon * (s + 2.0))
        )

    delta_thr = curve_a(0.0)

    if delta == delta_thr:
        alpha = 1.0
    else:
        above = delta > delta_thr
        curve = curve_a if above else curve_b

        def reached(s):
            # True once `s` has gone far enough to bracket the target delta.
            value = curve(s)
            return value >= delta if above else value <= delta

        def go_left(s):
            # True when the solution lies below `s`.
            value = curve(s)
            return value > delta if above else value < delta

        # Bracket the solution by repeated doubling of the upper end.
        lo, hi = 0.0, 1.0
        while not reached(hi):
            lo, hi = hi, 2.0 * hi

        # Bisect until the curve is within `tol` of delta.
        mid = lo + (hi - lo) / 2.0
        while not abs(curve(mid) - delta) <= tol:
            if go_left(mid):
                hi = mid
            else:
                lo = mid
            mid = lo + (hi - lo) / 2.0

        half = mid / 2.0
        if above:
            alpha = sqrt(1.0 + half) - sqrt(half)
        else:
            alpha = sqrt(1.0 + half) + sqrt(half)

    return alpha * GS / sqrt(2.0 * epsilon)
