from models.dp_cgan import *
import os


def _dp_epsilon(n, batch_size, noise_multiplier, epochs, delta):
    """Return the epsilon reported by ``compute_dp_sgd_privacy`` for a schedule."""
    eps, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
        n=n,
        batch_size=batch_size,
        noise_multiplier=noise_multiplier,
        epochs=epochs,
        delta=delta,
    )
    return eps


def _calibrate_dp_schedule(epsilon, n, batch_size, epochs, noise_multiplier, delta):
    """Adjust batch size, epochs and noise multiplier towards a target epsilon.

    Phase 1 makes the schedule more private until the accounted epsilon is at
    most ``epsilon``: first the batch size is cut to 10, then the epochs are
    reduced to at most 50, then the noise multiplier is raised in 0.1 steps.
    Phase 2 lowers the noise multiplier in 0.01 steps until the accounted
    epsilon exceeds the target, and then steps back once.

    Returns:
        Tuple ``(eps, batch_size, epochs, noise_multiplier)``.

    NOTE(review): ``eps`` is the last value computed in phase 2, i.e. the one
    for ``noise_multiplier - 0.01`` (the first value above the target). It is
    not recomputed after stepping back, so it slightly overstates the epsilon
    of the returned schedule. Left as-is because the value feeds the checkpoint
    file name used to detect previously trained models.
    """
    eps = _dp_epsilon(n, batch_size, noise_multiplier, epochs, delta)

    # Phase 1: tighten the schedule until the budget is respected.
    while eps > epsilon + 1e-10:
        if batch_size > 10:
            batch_size = 10
        elif epochs > 50:
            epochs = min(epochs - 50, 50)
        else:
            noise_multiplier += 0.1
        eps = _dp_epsilon(n, batch_size, noise_multiplier, epochs, delta)

    # Phase 2: spend the remaining budget by removing noise in small steps.
    while eps < epsilon + 1e-10:
        noise_multiplier -= 0.01
        eps = _dp_epsilon(n, batch_size, noise_multiplier, epochs, delta)
    # The last decrement overshot the budget; undo it.
    noise_multiplier += 0.01

    return eps, batch_size, epochs, noise_multiplier


def _decode_synthetic_targets(synth_data, num_classes):
    """Turn the generator's target outputs into discrete labels.

    Args:
        synth_data: 2-D float array of generated rows; the target occupies the
            last column (binary) or the last ``num_classes`` columns (one-hot).
        num_classes: number of distinct target values in the training data.

    Returns:
        The array with a single, discrete target column when ``num_classes`` is
        2 (values 0.0 / 1.0, written in place) or between 3 and 10 (argmax
        index of the one-hot block). For any other ``num_classes`` the array is
        returned unchanged.

    NOTE(review): in the multi-class case the label is the argmax *index*, not
    the original class value; the two coincide only if the classes are
    0..num_classes-1.
    """
    if num_classes == 2:
        # The generator applies a sigmoid to the target column, so its output
        # lies in (0, 1); 0.5 is the decision boundary. (A cut-off at 0 would
        # label practically every sample as 1.)
        synth_data[:, -1] = (synth_data[:, -1] > 0.5) * 1.0
    elif 2 < num_classes < 11:
        labels = synth_data[:, -num_classes:].argmax(1)[:, None]
        synth_data = np.concatenate([synth_data[:, :-num_classes], labels], axis=1)
    return synth_data


def fit_DPGAN(
    df_,
    n_s,
    Z_DIM,
    model_id,
    epsilon,
    delta,
    seed,
    wandb,
    NOISE_MULT=1.15,
    NORM_CLIP=1.1,
    reload=False,
):
    """Train (or restore) a differentially private GAN and sample from it.

    The last column of ``df_`` is treated as the target. If it has between 3
    and 10 distinct values it is one-hot encoded before training. The
    discriminator is trained with a DP gradient-descent optimizer, the
    generator with Adam. A checkpoint is written after every epoch under
    ``output/{model_id}/training_checkpoints``.

    Args:
        df_: pandas DataFrame with features first and the target last.
        n_s: number of synthetic rows to generate.
        Z_DIM: latent dimension; ``None`` or the string ``"None"`` selects
            ``min(feature_dim, 100)``.
        model_id: name of the sub-directory of ``output/`` used for results.
        epsilon: target privacy budget. Values below 20 trigger calibration of
            batch size, epochs and noise multiplier; otherwise the defaults are
            used and ``epsilon`` is only recorded.
        delta: DP delta; a falsy value selects ``1 / len(df_)``.
        seed: seed passed to ``tf.random.set_seed``.
        wandb: object providing ``config.update`` and ``log`` (used for
            experiment tracking).
        NOISE_MULT: initial noise multiplier of the DP optimizer.
        NORM_CLIP: L2 clipping norm of the DP optimizer.
        reload: when true, training is run even if the final checkpoint of a
            previous run with the same settings exists; otherwise that
            checkpoint is restored instead of training.

    Returns:
        Tuple ``(synth_data, log_iw)``: a DataFrame with ``n_s`` generated rows
        (target decoded to a single column for 2..10 classes) and the
        discriminator's output on the generated rows (raw logits, since the
        discriminator's last layer has no activation).
    """

    df_train = df_.copy().values
    num_features = df_train.shape[1] - 1
    num_classes = np.unique(df_.values[:, -1]).size
    if 2 < num_classes < 11:
        # One-hot encode a small multi-class target.
        dummies = pd.get_dummies(df_train[:, -1], prefix="t_")
        df_train = np.concatenate([df_train[:, :-1], dummies.values], axis=1)

    # region parameters
    EPOCHS = 200
    BATCH_SIZE = min(600, len(df_))
    N_GEN = n_s

    feature_dim = df_train.shape[1]
    Z_DIM = min(feature_dim, 100) if Z_DIM is None or Z_DIM == "None" else Z_DIM

    BUFFER_SIZE = len(df_train)  # Total size of training data

    # NOTE(review): the original comment says delta "needs to be smaller than
    # 1/BUFFER_SIZE", but the default is exactly 1/BUFFER_SIZE.
    DP_DELTA = 1 / BUFFER_SIZE if not delta else delta

    if epsilon < 20:
        eps, BATCH_SIZE, EPOCHS, NOISE_MULT = _calibrate_dp_schedule(
            epsilon, BUFFER_SIZE, BATCH_SIZE, EPOCHS, NOISE_MULT, DP_DELTA
        )
    else:
        # No calibration for large budgets: the requested value is recorded
        # as-is and the default schedule is used.
        eps = epsilon

    # Each batch of data is split in smaller units called microbatches; here
    # every example is its own microbatch.
    NR_MICROBATCHES = BATCH_SIZE

    configs = {
        "epsilon": eps,
        "delta": DP_DELTA,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "noise_multiplier": NOISE_MULT,
        "norm_clip": NORM_CLIP,
        "z_dim": Z_DIM,
        "n_gen": N_GEN,
        "feature_dim": feature_dim,
    }
    wandb.config.update({f"dpcgan_{k}": v for k, v in configs.items()})

    # Seed before any random tensor or network weight is created so that the
    # preview latent and the weight initialisation are controlled by `seed`.
    # The seed is set again right before the dataset is built (see below).
    tf.random.set_seed(seed)

    N_DISC = 1  # Number of times we train DISC before training GEN once

    LR_DISC = tf.compat.v1.train.polynomial_decay(
        learning_rate=0.150,
        global_step=tf.compat.v1.train.get_or_create_global_step(),
        decay_steps=10000,
        end_learning_rate=0.052,
        power=1,
    )

    # endregion

    # region logging
    results_dir = f"output/{model_id}"
    checkpoint_dir = f"{results_dir}/training_checkpoints"

    def checkpoint_name(title):
        """Return the checkpoint file prefix for ``title``."""
        return os.path.join(checkpoint_dir, "ckpt__" + str(title))

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Image previews are only produced for wide inputs.
    # presumably 28x28 images plus 10 one-hot label columns, given the reshape
    # below — confirm against the datasets actually used.
    if feature_dim > 740:
        import matplotlib.pyplot as plt

        images_dir = results_dir + "/images"
        os.makedirs(images_dir, exist_ok=True)

        def generate_and_save_images(title, model, epoch, test_input):
            """Render the generator output for ``test_input`` and save it as PNG."""
            # Notice `training` is set to False: This is so all layers run in inference mode (batchnorm).
            predictions = model([test_input], training=False)

            fig = plt.figure(figsize=(2, 10))

            # Drop the last 10 columns and view the rest as ten 28x28 images.
            images = predictions.numpy()[:, :-10].reshape((10, 28, 28, 1))
            for i in range(predictions.shape[0]):
                plt.subplot(10, 1, i + 1)
                plt.imshow(images[i, :, :, 0] * 127.5 + 127.5, cmap="gray")
                plt.axis("off")

            plt.savefig(
                images_dir + "/" + title + "___image_at_epoch_{:04d}.png".format(epoch)
            )
            plt.show()
            # Release the figure; otherwise one figure per epoch stays alive.
            plt.close(fig)

        # Fixed latent vectors so previews are comparable across epochs.
        seed_latent = tf.random.normal([10, Z_DIM])

    # endregion

    GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
    DPGradientDescentGaussianOptimizer_NEW = make_gaussian_optimizer_class(
        GradientDescentOptimizer
    )

    def make_generator_model_FCC():
        """Build the generator: latent -> Dense+ReLU -> Dense, tanh on features, sigmoid on targets."""

        # INPUT: image generator input
        in_lat = layers.Input(shape=(Z_DIM,))

        ge1 = layers.Dense(Z_DIM, use_bias=True)(in_lat)
        ge1 = layers.ReLU()(ge1)

        out_layer = layers.Dense(feature_dim, use_bias=True)(ge1)

        # The first `num_features` columns are features, the rest the target.
        features_pre = out_layer[:, :num_features]
        targets_pre = out_layer[:, num_features:]

        features_post = tf.keras.activations.tanh(features_pre)
        targets_post = tf.keras.activations.sigmoid(targets_pre)

        merge = layers.concatenate([features_post, targets_post], axis=1)

        return models.Model([in_lat], merge)

    def make_discriminator_model_FCC():
        """Build the discriminator: row -> Dense+ReLU -> single logit."""

        # INPUT: Image
        # Shape must be a tuple; `(feature_dim)` without the comma is an int.
        in_image = layers.Input(shape=(feature_dim,))
        in_image_b = layers.Flatten()(in_image)

        ge1 = layers.Dense(Z_DIM, use_bias=True)(in_image_b)
        ge1 = layers.ReLU()(ge1)

        out_layer = layers.Dense(1, use_bias=True)(ge1)

        return models.Model([in_image], out_layer)

    # Per-example losses (no reduction) for the discriminator, as handed to
    # the DP optimizer; mean-reduced loss for the generator.
    cross_entropy_DISC = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE
    )
    cross_entropy_GEN = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # Notice the use of `tf.function`: This annotation causes the function to be "compiled".
    @tf.function
    def train_step_DISC(images, noise):
        """Run one discriminator update and return the (real, fake) per-example losses."""
        with tf.GradientTape(persistent=True) as disc_tape_real:
            # This dummy call is needed to obtain the var list.
            discriminator([images], training=True)
            var_list = discriminator.trainable_variables

            # In Eager mode, the optimizer takes a function that returns the loss.
            def loss_fn_real():
                real_output = discriminator([images], training=True)
                return cross_entropy_DISC(tf.ones_like(real_output), real_output)

            # Gradients on real data: computed with the DP noise multiplier.
            grads_and_vars_real = discriminator_optimizer.compute_gradients(
                loss_fn_real,
                var_list,
                gradient_tape=disc_tape_real,
                curr_noise_mult=NOISE_MULT,
                curr_norm_clip=NORM_CLIP,
            )

            # In Eager mode, the optimizer takes a function that returns the loss.
            def loss_fn_fake():
                generated_images = generator([noise], training=True)
                fake_output = discriminator([generated_images], training=True)
                return cross_entropy_DISC(tf.zeros_like(fake_output), fake_output)

            # Gradients on generated data: noise multiplier 0 is passed.
            grads_and_vars_fake = discriminator_optimizer.compute_gradients(
                loss_fn_fake,
                var_list,
                gradient_tape=disc_tape_real,
                curr_noise_mult=0,
                curr_norm_clip=NORM_CLIP,
            )
            disc_loss_r = loss_fn_real()
            disc_loss_f = loss_fn_fake()

            # assumes the custom optimizer's compute_gradients returns one
            # gradient per variable in var_list order (not (grad, var) pairs),
            # since the sums are zipped with var_list below — TODO confirm
            s_grads_and_vars = [
                grad_real + grad_fake
                for grad_real, grad_fake in zip(
                    grads_and_vars_real, grads_and_vars_fake
                )
            ]
            sanitized_grads_and_vars = list(zip(s_grads_and_vars, var_list))

            discriminator_optimizer.apply_gradients(sanitized_grads_and_vars)

        return (disc_loss_r, disc_loss_f)

    # Notice the use of `tf.function`: This annotation causes the function to be "compiled".
    @tf.function
    def train_step_GEN(noise):
        """Run one generator update and return its loss."""
        with tf.GradientTape() as gen_tape:
            generated_images = generator([noise], training=True)
            fake_output = discriminator([generated_images], training=True)
            # if the generator is performing well, the discriminator will classify the fake images as real (or 1)
            gen_loss = cross_entropy_GEN(tf.ones_like(fake_output), fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, generator.trainable_variables
        )
        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables)
        )

        return gen_loss

    def train(dataset, title, verbose):
        """Train for EPOCHS epochs, logging losses and checkpointing each epoch."""
        for epoch in range(EPOCHS):
            start = time.time()

            i_gen = 0
            for image_batch in dataset:
                if verbose:
                    print("Iteration: " + str(i_gen + 1))

                noise = tf.random.normal([BATCH_SIZE, Z_DIM])

                d_loss_r, d_loss_f = train_step_DISC(image_batch, noise)
                if verbose:
                    print("Loss DISC Real: " + str(tf.reduce_mean(d_loss_r)))
                    print("Loss DISC Fake: " + str(tf.reduce_mean(d_loss_f)))
                # Kept as separate log calls so the logging cadence is
                # unchanged.
                wandb.log({"Loss DISC Real": tf.reduce_mean(d_loss_r)})
                wandb.log({"Loss DISC Fake": tf.reduce_mean(d_loss_f)})

                # The generator is updated once every N_DISC discriminator steps.
                if (i_gen + 1) % N_DISC == 0:
                    g_loss_f = train_step_GEN(noise)
                    if verbose:
                        print("Loss GEN Fake:: " + str(g_loss_f))
                    wandb.log({"Loss GEN Fake": g_loss_f})

                i_gen = i_gen + 1

            print("Time for epoch {} is {} sec".format(epoch + 1, time.time() - start))

            if feature_dim > 740:
                generate_and_save_images(title, generator, epoch + 1, seed_latent)

            # Save the model
            checkpoint.save(
                file_prefix=checkpoint_name(title + "__epoch=" + str(epoch) + "__")
            )

    generator_optimizer = tf.keras.optimizers.Adam()

    discriminator_optimizer = DPGradientDescentGaussianOptimizer_NEW(
        learning_rate=LR_DISC,
        l2_norm_clip=NORM_CLIP,
        noise_multiplier=NOISE_MULT,
        num_microbatches=NR_MICROBATCHES,
    )

    # Create/reinitiate models
    generator = make_generator_model_FCC()
    discriminator = make_discriminator_model_FCC()

    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )

    # Re-seed so shuffling and training noise start from the same state as
    # before the early seeding was introduced.
    tf.random.set_seed(seed)

    # Batch and random shuffle training data
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((df_train))
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )

    # GIVES CURRENT TRIAL A NAME - Suggestion: from parameters used
    training_title = f"eps{eps:.2f}"

    # Path of the checkpoint written after the last epoch: `checkpoint.save`
    # appends "-<save counter>", which equals EPOCHS after a full fresh run.
    checkpoint_name_str = (
        checkpoint_dir
        + "/ckpt__"
        + str(training_title)
        + "__epoch="
        + str(EPOCHS - 1)
        + "__-"
        + str(EPOCHS)
    )

    wandb.config.update(
        {"checkpoint_name_str": os.getcwd() + "/" + checkpoint_name_str}
    )

    # Train unless a finished run can be restored (and retraining is not forced).
    if not os.path.exists(checkpoint_name_str + ".index") or reload:
        train(train_dataset, training_title, False)
    else:
        checkpoint.restore(checkpoint_name_str)

    # Re-seed so the sampled latent vectors do not depend on whether the model
    # was trained or restored.
    tf.random.set_seed(seed)
    noise_GEN = tf.random.normal([N_GEN, Z_DIM])

    images_GEN = generator([noise_GEN], training=False)
    images_flat = layers.Flatten()(images_GEN)

    # Discriminator output on the synthetic rows.
    log_iw = discriminator([images_GEN], training=False)

    synth_data = _decode_synthetic_targets(images_flat.numpy(), num_classes)
    synth_data = pd.DataFrame(synth_data)

    return synth_data, log_iw
