import numpy as np
import os
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow.keras.backend as K
import time

import pandas as pd

from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
from tensorflow_privacy.privacy.optimizers.dp_optimizer import (
    DPGradientDescentGaussianOptimizer,
)

# Enable on-demand memory growth for every GPU TensorFlow can see.
# NOTE(review): presumably so TF does not reserve all GPU memory at start-up
# -- confirm against the deployment environment.
physical_devices = tf.config.list_physical_devices('GPU') 
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

# from dp_optimizer import make_gaussian_optimizer_class
# from utils import checkpoint_name


# Method obtained from https://stackoverflow.com/questions/41123879/numpy-random-choice-in-tensorflow
def _random_choice(inputs, n_samples):
    """Draw rows of `inputs` uniformly at random, with replacement.

    Params:
      inputs (Tensor): Shape [n_states, n_features]
      n_samples (int): The number of random samples to take.
    Returns:
      sampled_inputs (Tensor): Shape [n_samples, n_features]
    """
    # All-zero logits give every row the same probability. The sampler
    # requires 2D logits, hence the leading axis of size 1: (1, n_states).
    uniform_log_prob = tf.expand_dims(tf.zeros(tf.shape(inputs)[0]), 0)

    # tf.random.categorical is the supported replacement for the deprecated
    # tf.compat.v1.multinomial; it takes the same (logits, num_samples)
    # arguments.
    ind = tf.random.categorical(uniform_log_prob, n_samples)
    ind = tf.squeeze(ind, 0, name="random_choice_ind")  # (n_samples,)

    return tf.gather(inputs, ind, name="random_choice")


# region redefine the optimizer

from absl import logging
import collections

from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query


def make_optimizer_class(cls):
    """Constructs a DP optimizer class from an existing one.

    Args:
        cls: Optimizer class to derive from. Its `compute_gradients` is
            compared against `tf.compat.v1.train.Optimizer.compute_gradients`
            and a warning is logged when it has been overridden.

    Returns:
        A subclass of `cls` whose `compute_gradients` accumulates
        per-microbatch gradients through a Gaussian sum query.
    """
    # Compare code objects to detect whether `cls` ships its own
    # compute_gradients implementation.
    parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
    child_code = cls.compute_gradients.__code__
    GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
    if child_code is not parent_code:
        logging.warning(
            "WARNING: Calling make_optimizer_class() on class %s that overrides "
            "method compute_gradients(). Check to ensure that "
            "make_optimizer_class() does not interfere with overridden version.",
            cls.__name__,
        )

    class DPOptimizerClass(cls):
        """Differentially private subclass of given class cls."""

        # Not referenced anywhere else in this class.
        _GlobalState = collections.namedtuple(
            "_GlobalState", ["l2_norm_clip", "stddev"]
        )

        def __init__(
            self,
            dp_sum_query,
            num_microbatches=None,
            unroll_microbatches=False,
            *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
            **kwargs,
        ):
            """Initialize the DPOptimizerClass.

            Args:
                dp_sum_query: DPQuery object, specifying differential privacy
                mechanism to use.
                num_microbatches: How many microbatches into which the minibatch is
                split. If None, will default to the size of the minibatch, and
                per-example gradients will be computed.
                unroll_microbatches: If true, processes microbatches within a Python
                loop instead of a tf.while_loop. Can be used if using a tf.while_loop
                raises an exception.
                *args: Forwarded to the parent optimizer's constructor.
                **kwargs: Forwarded to the parent optimizer's constructor.
            """
            super(DPOptimizerClass, self).__init__(*args, **kwargs)
            self._dp_sum_query = dp_sum_query
            self._num_microbatches = num_microbatches
            self._global_state = self._dp_sum_query.initial_global_state()
            # TODO(b/122613513): Set unroll_microbatches=True to avoid this bug.
            # Beware: When num_microbatches is large (>100), enabling this parameter
            # may cause an OOM error.
            # NOTE(review): the flag is stored but compute_gradients below never
            # reads it; microbatches are always processed in a Python loop.
            self._unroll_microbatches = unroll_microbatches

        def compute_gradients(
            self,
            loss,
            var_list,
            gate_gradients=GATE_OP,
            aggregation_method=None,
            colocate_gradients_with_ops=False,
            grad_loss=None,
            gradient_tape=None,
            curr_noise_mult=0,
            curr_norm_clip=1,
        ):
            """Compute microbatch-averaged gradients through a Gaussian sum query.

            Every call replaces `self._dp_sum_query` and `self._global_state`
            with a fresh `GaussianSumQuery` built from `curr_norm_clip` and
            `curr_noise_mult`, so the query passed to `__init__` is not the
            one used here.

            Args:
                loss: Zero-argument callable returning the loss as a vector;
                    it is reshaped to `[num_microbatches, -1]`.
                var_list: Variables to differentiate with respect to.
                gate_gradients: Not read by this implementation.
                aggregation_method: Not read by this implementation.
                colocate_gradients_with_ops: Not read by this implementation.
                grad_loss: Not read by this implementation.
                gradient_tape: Tape used to compute the gradients; required.
                    `gradient` is called on it once per microbatch.
                curr_noise_mult: Noise multiplier for this call. The query's
                    stddev is `curr_norm_clip * curr_noise_mult`; when it is
                    not greater than 0 the noising step is skipped.
                curr_norm_clip: L2 clipping norm handed to the query.

            Returns:
                The accumulated gradients divided by the number of
                microbatches. These are the gradients only, not
                `(gradient, variable)` pairs; the caller must zip them with
                `var_list`.

            Raises:
                ValueError: If `gradient_tape` is falsy.
            """

            self._dp_sum_query = gaussian_query.GaussianSumQuery(
                curr_norm_clip, curr_norm_clip * curr_noise_mult
            )
            self._global_state = self._dp_sum_query.make_global_state(
                curr_norm_clip, curr_norm_clip * curr_noise_mult
            )

            # TF is running in Eager mode, check we received a vanilla tape.
            if not gradient_tape:
                raise ValueError("When in Eager mode, a tape needs to be passed.")

            vector_loss = loss()
            # With no explicit microbatch count, use one microbatch per entry
            # of the loss vector. The value is stored on the instance, so it
            # is reused by later calls.
            if self._num_microbatches is None:
                self._num_microbatches = tf.shape(input=vector_loss)[0]
            sample_state = self._dp_sum_query.initial_sample_state(var_list)
            microbatches_losses = tf.reshape(vector_loss, [self._num_microbatches, -1])
            sample_params = self._dp_sum_query.derive_sample_params(self._global_state)

            def process_microbatch(i, sample_state):
                """Process one microbatch (record) with privacy helper."""
                microbatch_loss = tf.reduce_mean(
                    input_tensor=tf.gather(microbatches_losses, [i])
                )
                grads = gradient_tape.gradient(microbatch_loss, var_list)
                # NOTE(review): presumably the query clips `grads` to
                # `curr_norm_clip` while accumulating -- confirm against the
                # installed tensorflow_privacy version.
                sample_state = self._dp_sum_query.accumulate_record(
                    sample_params, sample_state, grads
                )
                return sample_state

            for idx in range(self._num_microbatches):
                sample_state = process_microbatch(idx, sample_state)

            # Only run the noising step when noise was requested; otherwise
            # use the accumulated state directly as the gradient sums.
            if curr_noise_mult > 0:
                (grad_sums, self._global_state,) = self._dp_sum_query.get_noised_result(
                    sample_state, self._global_state
                )
            else:
                grad_sums = sample_state

            def normalize(v):
                # Turn the sum over microbatches into an average.
                return v / tf.cast(self._num_microbatches, tf.float32)

            final_grads = tf.nest.map_structure(normalize, grad_sums)
            grads_and_vars = final_grads  # list(zip(final_grads, var_list))

            return grads_and_vars

    return DPOptimizerClass


def make_gaussian_optimizer_class(cls):
    """Constructs a DP optimizer with Gaussian averaging of updates.

    Args:
        cls: Optimizer class to wrap; it is first passed through
            `make_optimizer_class`.

    Returns:
        A subclass whose constructor takes `l2_norm_clip` and
        `noise_multiplier` and builds the Gaussian sum query itself.
    """

    class DPGaussianOptimizerClass(make_optimizer_class(cls)):
        """DP subclass of given class cls using Gaussian averaging."""

        def __init__(
            self,
            l2_norm_clip,
            noise_multiplier,
            num_microbatches=None,
            ledger=None,
            unroll_microbatches=False,
            *args,  # pylint: disable=keyword-arg-before-vararg
            **kwargs,
        ):
            """Build the Gaussian sum query and initialise the DP optimizer.

            Args:
                l2_norm_clip: Clipping norm passed to the Gaussian sum query.
                noise_multiplier: Ratio of the noise stddev to the clipping
                    norm; the query stddev is `l2_norm_clip * noise_multiplier`.
                num_microbatches: Forwarded to the DP optimizer base class.
                ledger: Optional privacy ledger. When truthy, the query is
                    wrapped in `privacy_ledger.QueryWithLedger`.
                unroll_microbatches: Forwarded to the DP optimizer base class.
                *args: Forwarded to the wrapped optimizer's constructor.
                **kwargs: Forwarded to the wrapped optimizer's constructor.
            """
            dp_sum_query = gaussian_query.GaussianSumQuery(
                l2_norm_clip, l2_norm_clip * noise_multiplier
            )

            if ledger:
                dp_sum_query = privacy_ledger.QueryWithLedger(
                    dp_sum_query, ledger=ledger
                )

            super(DPGaussianOptimizerClass, self).__init__(
                dp_sum_query, num_microbatches, unroll_microbatches, *args, **kwargs
            )

        # This property used to be defined inside __init__, which made it a
        # dead local function; it has to live on the class to be reachable
        # as `optimizer.ledger`.
        @property
        def ledger(self):
            """Ledger of the current sum query.

            Raises AttributeError when the current query exposes no `ledger`
            attribute (NOTE(review): presumably the case for a query that was
            not wrapped in `QueryWithLedger` -- confirm).
            """
            return self._dp_sum_query.ledger

    return DPGaussianOptimizerClass


# endregion


def DP_CGAN(df_, n_s, Z_DIM, model_id, epsilon, delta, seed, wandb, NOISE_MULT=1.15, NORM_CLIP=1.1, reload=False):
    """Train (or restore) a DP conditional GAN and sample synthetic data from it.

    Args:
        df_: Training DataFrame. The last column is the class label and is
            used as an integer index into a one-hot matrix, so labels must
            be integers in `[0, number_of_classes)`. All other columns are
            features.
        n_s: Requested number of synthetic rows. It is rounded down to a
            multiple of the number of classes.
        Z_DIM: Latent dimension. `None` or the string `"None"` selects
            `min(feature_dim, 100)`.
        model_id: Identifier used for the output directory
            `output/{model_id}`.
        epsilon: Target privacy budget. Below 20, batch size, epochs and
            noise multiplier are calibrated towards it; otherwise the value
            is used as-is and no calibration runs.
        delta: Privacy delta. A falsy value selects `1 / len(df_)`.
        seed: Seed passed to `tf.random.set_seed`.
        wandb: Object exposing `config.update(...)` and `log(...)`
            (presumably the `wandb` module or a run object).
        NOISE_MULT: Initial noise multiplier for the discriminator optimizer.
        NORM_CLIP: L2 clipping norm for the discriminator gradients.
        reload: When true, training runs even if the final checkpoint
            already exists on disk.

    Returns:
        A `(synth_data, log_iw)` tuple: a DataFrame of generated features
        with the class index as its last column, and the discriminator
        output for those generated samples.
    """

    # region parameters
    EPOCHS = 200
    BATCH_SIZE = min(600, len(df_))

    N_GEN = n_s
    df_train = df_
    feature_dim = df_train.shape[1] - 1
    Z_DIM = min(feature_dim, 100) if Z_DIM is None or Z_DIM == "None" else Z_DIM

    # Number of distinct classes in the label column.
    COND_num_classes = df_train.iloc[:, -1].nunique()

    BUFFER_SIZE = len(df_train)  # Total size of training data

    # Fall back to 1/N when no delta is supplied.
    DP_DELTA = 1 / BUFFER_SIZE if not delta else delta

    def _epsilon_for(batch_size, noise_multiplier, epochs):
        """Return the epsilon reported by the DP-SGD accountant."""
        eps_value, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
            n=BUFFER_SIZE,
            batch_size=batch_size,
            noise_multiplier=noise_multiplier,
            epochs=epochs,
            delta=DP_DELTA,
        )
        return eps_value

    if epsilon < 20:
        eps = _epsilon_for(BATCH_SIZE, NOISE_MULT, EPOCHS)

        # Coarse pass: shrink the batch, then the epochs, then raise the
        # noise until the budget is met.
        while eps > epsilon + 1e-10:
            if BATCH_SIZE > 10:
                BATCH_SIZE = 10
            elif EPOCHS > 50:
                # NOTE(review): min() drops straight to 50 epochs; max()
                # would step down by 50 at a time -- confirm the intent.
                EPOCHS = min(EPOCHS - 50, 50)
            else:
                NOISE_MULT += 0.1
            eps = _epsilon_for(BATCH_SIZE, NOISE_MULT, EPOCHS)

        # Fine pass: lower the noise until the budget is just exceeded, then
        # step back to the last multiplier that still met it.
        while eps < epsilon + 1e-10:
            NOISE_MULT -= 0.01
            eps = _epsilon_for(BATCH_SIZE, NOISE_MULT, EPOCHS)
        NOISE_MULT += 0.01
        # NOTE(review): `eps` still holds the value for the rejected (smaller)
        # multiplier, so the logged epsilon and the checkpoint title do not
        # match the final NOISE_MULT. Left unchanged because recomputing it
        # would rename existing checkpoints.
    else:
        eps = epsilon

    # Each batch of data is split in smaller units called microbatches;
    # here every example is its own microbatch.
    NR_MICROBATCHES = BATCH_SIZE

    configs = {
        "epsilon": eps,
        "delta": DP_DELTA,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "noise_multiplier": NOISE_MULT,
        "norm_clip": NORM_CLIP,
        "z_dim": Z_DIM,
        "n_gen": N_GEN,
        "cond_num_classes": COND_num_classes,
        "feature_dim": feature_dim,
    }
    wandb.config.update({f"dpcgan_{k}": v for k, v in configs.items()})

    N_DISC = 1  # Number of times we train DISC before training GEN once

    LR_DISC = tf.compat.v1.train.polynomial_decay(
        learning_rate=0.150,
        global_step=tf.compat.v1.train.get_or_create_global_step(),
        decay_steps=10000,
        end_learning_rate=0.052,
        power=1,
    )

    # endregion

    # region logging
    results_dir = f"output/{model_id}"
    checkpoint_dir = f"{results_dir}/training_checkpoints"

    def checkpoint_name(title):
        """Return the checkpoint file prefix for `title`."""
        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt__" + str(title))
        return checkpoint_prefix

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Image previews are only produced for wide inputs; the helper below
    # hard-codes 10 samples reshaped to 28x28.
    if feature_dim > 740:
        import matplotlib.pyplot as plt

        images_dir = results_dir + "/images"
        if not os.path.exists(images_dir):
            os.makedirs(images_dir)

        def generate_and_save_images(title, model, epoch, test_input, test_label):
            """Render the generator output for the fixed seed inputs to a PNG."""
            # Notice `training` is set to False: This is so all layers run in inference mode (batchnorm).
            predictions = model([test_input, test_label], training=False)

            fig = plt.figure(figsize=(2, 10))

            for i in range(predictions.shape[0]):
                plt.subplot(10, 1, i + 1)
                plt.imshow(
                    predictions.numpy().reshape((10, 28, 28, 1))[i, :, :, 0] * 127.5
                    + 127.5,
                    cmap="gray",
                )
                plt.axis("off")

            plt.savefig(
                images_dir + "/" + title + "___image_at_epoch_{:04d}.png".format(epoch)
            )
            plt.show()

        seed_latent = tf.random.normal([10, Z_DIM])
        seed_labels = tf.Variable(
            np.diag(np.full(10, 1)).reshape((10, 10)), dtype="float32"
        )

    # endregion

    GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
    DPGradientDescentGaussianOptimizer_NEW = make_gaussian_optimizer_class(
        GradientDescentOptimizer
    )

    # One one-hot row per class; sampled from to condition the generator.
    labels_gen_vec = np.eye((COND_num_classes), dtype="float32")

    def make_generator_model_FCC():
        """Build the fully connected generator: (latent, label) -> features."""
        # INPUT: label input
        in_label = layers.Input(shape=(COND_num_classes,))

        # INPUT: image generator input
        in_lat = layers.Input(shape=(Z_DIM,))

        # MERGE
        merge = layers.concatenate([in_lat, in_label], axis=1)

        ge1 = layers.Dense(Z_DIM, use_bias=True)(merge)
        ge1 = layers.ReLU()(ge1)

        out_layer = layers.Dense(feature_dim, use_bias=True, activation="tanh")(ge1)

        model = models.Model([in_lat, in_label], out_layer)

        return model

    def make_discriminator_model_FCC():
        """Build the fully connected discriminator: (features, label) -> logit."""
        # INPUT: Label
        in_label = layers.Input(shape=(COND_num_classes,))

        # INPUT: Image
        # `shape` must be a tuple; `(feature_dim)` without the comma is a
        # bare int.
        in_image = layers.Input(shape=(feature_dim,))
        in_image_b = layers.Flatten()(in_image)

        # MERGE
        merge = layers.concatenate([in_image_b, in_label], axis=1)

        ge1 = layers.Dense(Z_DIM, use_bias=True)(merge)
        ge1 = layers.ReLU()(ge1)

        out_layer = layers.Dense(1, use_bias=True)(ge1)

        model = models.Model([in_image, in_label], out_layer)

        return model

    # The discriminator loss stays per-example (no reduction) because the DP
    # optimizer splits it into microbatches itself.
    cross_entropy_DISC = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.losses.Reduction.NONE
    )
    cross_entropy_GEN = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # Notice the use of `tf.function`: This annotation causes the function to be "compiled".
    @tf.function
    def train_step_DISC(images, labels, noise, labels_to_gen):
        """Run one discriminator update and return (real loss, fake loss)."""
        with tf.GradientTape(persistent=True) as disc_tape_real:
            # This dummy call is needed to obtain the var list.
            dummy = discriminator([images, labels], training=True)
            var_list = discriminator.trainable_variables

            # In Eager mode, the optimizer takes a function that returns the loss.
            def loss_fn_real():
                real_output = discriminator([images, labels], training=True)
                disc_real_loss = cross_entropy_DISC(
                    tf.ones_like(real_output), real_output
                )
                return disc_real_loss

            # Gradients on real data go through the noised query.
            grads_and_vars_real = discriminator_optimizer.compute_gradients(
                loss_fn_real,
                var_list,
                gradient_tape=disc_tape_real,
                curr_noise_mult=NOISE_MULT,
                curr_norm_clip=NORM_CLIP,
            )

            # In Eager mode, the optimizer takes a function that returns the loss.
            def loss_fn_fake():
                generated_images = generator([noise, labels_to_gen], training=True)
                fake_output = discriminator(
                    [generated_images, labels_to_gen], training=True
                )
                disc_fake_loss = cross_entropy_DISC(
                    tf.zeros_like(fake_output), fake_output
                )
                return disc_fake_loss

            # Gradients on generated data are requested with a zero noise
            # multiplier, which makes the optimizer skip its noising step.
            grads_and_vars_fake = discriminator_optimizer.compute_gradients(
                loss_fn_fake,
                var_list,
                gradient_tape=disc_tape_real,
                curr_noise_mult=0,
                curr_norm_clip=NORM_CLIP,
            )
            disc_loss_r = loss_fn_real()
            disc_loss_f = loss_fn_fake()

            # compute_gradients returns bare gradients here, so sum the two
            # contributions per variable and pair them with the variables.
            s_grads_and_vars = [
                (grads_and_vars_real[idx] + grads_and_vars_fake[idx])
                for idx in range(len(grads_and_vars_real))
            ]
            sanitized_grads_and_vars = list(zip(s_grads_and_vars, var_list))

            discriminator_optimizer.apply_gradients(sanitized_grads_and_vars)

        return (disc_loss_r, disc_loss_f)

    # Notice the use of `tf.function`: This annotation causes the function to be "compiled".
    @tf.function
    def train_step_GEN(labels, noise):
        """Run one generator update and return the generator loss."""
        with tf.GradientTape() as gen_tape:
            generated_images = generator([noise, labels], training=True)
            fake_output = discriminator([generated_images, labels], training=True)
            # if the generator is performing well, the discriminator will classify the fake images as real (or 1)
            gen_loss = cross_entropy_GEN(tf.ones_like(fake_output), fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, generator.trainable_variables
        )
        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables)
        )

        return gen_loss

    def train(dataset, title, verbose):
        """Train both networks for EPOCHS epochs, checkpointing after each."""
        for epoch in range(EPOCHS):
            start = time.time()

            i_gen = 0
            for image_batch, label_batch in dataset:
                if verbose:
                    print("Iteration: " + str(i_gen + 1))

                noise = tf.random.normal([BATCH_SIZE, Z_DIM])
                labels_to_gen = _random_choice(labels_gen_vec, BATCH_SIZE)

                d_loss_r, d_loss_f = train_step_DISC(
                    image_batch, label_batch, noise, labels_to_gen
                )
                if verbose:
                    print("Loss DISC Real: " + str(tf.reduce_mean(d_loss_r)))
                    print("Loss DISC Fake: " + str(tf.reduce_mean(d_loss_f)))
                wandb.log({"Loss DISC Real": tf.reduce_mean(d_loss_r)})
                wandb.log({"Loss DISC Fake": tf.reduce_mean(d_loss_f)})

                if (i_gen + 1) % N_DISC == 0:
                    g_loss_f = train_step_GEN(labels_to_gen, noise)
                    if verbose:
                        print("Loss GEN Fake:: " + str(g_loss_f))
                    wandb.log({"Loss GEN Fake": g_loss_f})

                i_gen = i_gen + 1

            print("Time for epoch {} is {} sec".format(epoch + 1, time.time() - start))

            if feature_dim > 740:
                generate_and_save_images(
                    title, generator, epoch + 1, seed_latent, seed_labels
                )

            # Save the model
            checkpoint.save(
                file_prefix=checkpoint_name(title + "__epoch=" + str(epoch) + "__")
            )

    generator_optimizer = tf.keras.optimizers.Adam()

    discriminator_optimizer = DPGradientDescentGaussianOptimizer_NEW(
        learning_rate=LR_DISC,
        l2_norm_clip=NORM_CLIP,
        noise_multiplier=NOISE_MULT,
        num_microbatches=NR_MICROBATCHES,
    )

    # Create/reinitiate models
    generator = make_generator_model_FCC()
    discriminator = make_discriminator_model_FCC()

    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )

    # NOTE(review): the models above are built before this call, so their
    # initial weights may not be governed by `seed` -- confirm whether that
    # is intended.
    tf.random.set_seed(seed)

    # Batch and random shuffle training data
    labels_dataset = np.eye(COND_num_classes)[df_train.values[:, -1].astype(int)]
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((df_train.values[:, :-1], labels_dataset))
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )

    # GIVES CURRENT TRIAL A NAME - Suggestion: from parameters used
    training_title = f"eps{eps:.2f}"

    # Path of the checkpoint written after the last epoch. The trailing
    # "-EPOCHS" presumably matches the save counter that Checkpoint.save
    # appends after one save per epoch -- TODO confirm.
    checkpoint_name_str = (
        checkpoint_dir
        + "/ckpt__"
        + str(training_title)
        + "__epoch="
        + str(EPOCHS - 1)
        + "__-"
        + str(EPOCHS)
    )

    print("#####################################################################")
    print(checkpoint_name_str)
    wandb.log({"checkpoint_name_str": checkpoint_name_str})
    print("#####################################################################")
    # Train when no final checkpoint exists or when `reload` forces it;
    # otherwise restore the saved weights.
    if not os.path.exists(checkpoint_name_str + ".index") or reload:
        train(train_dataset, training_title, False)
    else:
        checkpoint.restore(checkpoint_name_str)

    # Generate the same number of rows for every class. `np.int` was removed
    # from NumPy, so the builtin `int` (which it aliased) is used instead.
    N_GEN_per_CLASS = int(N_GEN / COND_num_classes)
    N_GEN = int(N_GEN_per_CLASS * COND_num_classes)

    tf.random.set_seed(seed)
    noise_GEN = tf.random.normal([N_GEN, Z_DIM])
    labels_GEN = tf.Variable(
        np.tile(np.eye(COND_num_classes, dtype="float32"), (N_GEN_per_CLASS, 1)) * 1.0
    )

    images_GEN = generator([noise_GEN, labels_GEN], training=False)
    images_flat = layers.Flatten()(images_GEN)

    # Convert the one-hot labels back to a single class-index column.
    labels_flat = tf.Variable(tf.math.argmax(labels_GEN, 1)[:, None])

    # Discriminator output on the generated samples, returned to the caller.
    log_iw = discriminator([images_GEN, labels_GEN], training=False)

    synth_data = pd.DataFrame(np.concatenate((images_flat, labels_flat), 1))

    return synth_data, log_iw
