import copy

import tensorflow as tf
import numpy as np

@tf.function
def batch_dataset(data, label, batch_length) -> tf.data.Dataset:
    """Build a shuffled, batched, prefetched dataset from in-memory samples.

    Args:
        data: per-sample inputs; sliced along the first axis.
        label: per-sample labels; sliced along the first axis.
        batch_length: number of samples per batch. The final batch may be
            smaller because the remainder is kept (``drop_remainder=False``).

    Returns:
        A ``tf.data.Dataset`` of ``(data, label)`` batches that is reshuffled on
        every iteration.
    """
    print('Tracing batch_dataset')  # Python side effect: only runs while tracing
    # Number of samples is the length of the first axis. ``tf.size`` would count
    # every element (e.g. N * C for one-hot labels) rather than the sample count.
    dataset_size = tf.shape(label, out_type=tf.int64)[0]
    dataset = tf.data.Dataset.from_tensor_slices((data, label))
    # A buffer as large as the dataset gives a full (uniform) shuffle.
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_length, drop_remainder=False)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

@tf.function
def get_batched_dataset_specs(dataset):
    """Inspect a batched ``(data, label)`` dataset.

    Returns:
        ``(sample_width, sample_height, batch_number, last_batch_length)``:
        dimensions 1 and 2 of the first batch's ``data`` (int32 scalars), the
        dataset cardinality (int64), and the number of samples in the final
        batch (int64).

    NOTE: the final batch size is found by iterating the whole dataset once,
    so ``dataset`` must be finite.
    """
    print('Tracing get_batched_dataset_specs')  # only runs while tracing
    data, _ = dataset.take(1).get_single_element()
    # assumes ``data`` batches have rank >= 3 (batch, width, height, ...) — TODO confirm
    batch_shape = tf.shape(data)
    sample_width, sample_height = batch_shape[1], batch_shape[2]
    batch_number = tf.data.experimental.cardinality(dataset)
    # Size of the final batch: the fold keeps each element's batch size, so the
    # result is the size of the last one. (``cardinality % batch_size`` would mix
    # a batch count with a batch size and is not a sample count.)
    last_batch_length = dataset.reduce(
        tf.constant(0, dtype=tf.int64),
        lambda _, element: tf.shape(element[0], out_type=tf.int64)[0])
    return sample_width, sample_height, batch_number, last_batch_length


@tf.function
def step_train(client_model, data, label, loss_func):
    """Compute one batch's gradients without applying them.

    Returns:
        ``(gradients, logits)``: the gradients of ``loss_func(label, logits)``
        with respect to ``client_model.trainable_variables``, and the model
        outputs computed with ``training=True``.
    """
    print('Tracing step_train')  # Python side effect: only runs while tracing
    with tf.GradientTape() as recorder:
        predictions = client_model(data, training=True)
        batch_loss = loss_func(label, predictions)
    trainable = client_model.trainable_variables
    grads = recorder.gradient(batch_loss, trainable)
    return grads, predictions

@tf.function
def step_test(model, data, label, loss_metric, accuracy_metric):
    """Evaluate ``model`` on one batch and accumulate the results into metrics.

    ``loss_metric`` and then ``accuracy_metric`` each receive
    ``update_state(label, logits)``, where ``logits`` are computed with
    ``training=False``. Nothing is returned.
    """
    print('Tracing step_test...')  # Python side effect: only runs while tracing
    predictions = model(data, training=False)
    # Plain Python tuple: unrolled at trace time, same order as two explicit calls.
    for metric in (loss_metric, accuracy_metric):
        metric.update_state(label, predictions)

def apply_mirror_descent_update(variable, gradient, learning_rate):
    """Return the multiplicative mirror-descent update of ``variable``.

    Computes ``variable * exp(-learning_rate * gradient)`` elementwise and
    returns it; ``variable`` itself is not modified (callers assign the result).
    """
    scaled_step = -learning_rate * gradient
    return variable * tf.exp(scaled_step)

# @tf.function
def client_update(client_model, training_dataset, loss_func, optimizer, loss_metric, accuracy_func, local_epoch_number, masks=None, client_idx=-1, verbose=False, tmp_model=None, compressor=None, config=None):
    """Train ``client_model`` in place for ``local_epoch_number`` local epochs.

    For each batch, gradients of ``loss_func(label, logits)`` are applied right
    away with ``optimizer``. The inner ``md`` flag, off by default, switches this
    to a multiplicative mirror-descent step.

    Masks: during the first epoch only, when ``masks`` is a sequence (neither
    ``None`` nor the sentinel ``-1``), the i-th batch is run with
    ``masks=masks[i % len(masks)]``.

    KL projection: when ``config.compressor.project_kl_divergences`` is set and
    ``compressor.compute_model_kls(tmp_model, client_model)`` exceeds it after an
    optimizer step, the model is updated from
    ``compressor.project_model_onto_kl_ball`` and the rest of that epoch is
    skipped.

    Args:
        client_model: model called as ``client_model(data, training=True[, masks=...])``.
        training_dataset: batched dataset of ``(data, label)`` pairs.
        loss_func: callable ``loss_func(label, logits)`` returning the loss.
        optimizer: provides ``apply_gradients`` and ``learning_rate``.
        loss_metric, accuracy_func: accepted for interface compatibility; unused.
        local_epoch_number: number of passes over ``training_dataset``.
        masks: ``None`` or ``-1`` for no masks, otherwise an indexable sequence.
        client_idx: unused.
        verbose: print debug information for batches 0 and 45.
        tmp_model: reference model for the KL check.
        compressor: object implementing the KL check and projection.
        config: optional config; ``None`` disables the KL projection.

    Returns:
        ``(0, 0, 0)``; the training result lives in ``client_model``.
    """
    # @tf.function
    def local_update(client_model, dataset, loss_func, optimizer, loss_metric, accuracy_func, epoch_number,
                     batch_length, last_batch_length, masks=None, client_idx=client_idx, verbose=verbose, md=False):
        """Run the epoch/batch loop.

        Returns the accumulator of the disabled accumulate-then-apply path, which
        is all zeros (or ``[]`` when ``epoch_number`` is 0).
        """
        batch_number = tf.data.experimental.cardinality(dataset)
        # ``-1`` is a "no masks" sentinel; compare by value, not by identity.
        use_masks = masks is not None and not (isinstance(masks, int) and masks == -1)
        # A missing config (the default) disables the KL projection instead of
        # raising AttributeError on the first optimizer step.
        kl_limit = None if config is None else config.compressor.project_kl_divergences
        mask_idx = 0
        gradients = []  # defined even when epoch_number == 0

        for epoch in range(int(epoch_number)):
            gradients = [tf.zeros_like(var) for var in client_model.trainable_variables]

            for batch, (data, label) in dataset.enumerate(start=0):
                if batch_length is None:
                    # Full batch size, read from the first batch (only the last can be short).
                    batch_length = tf.shape(data)[0]
                debug_batch = verbose and (batch < 1 or batch == 45)
                if debug_batch:
                    print(batch)

                call_kwargs = {}
                if use_masks and epoch < 1:
                    # Cycle through the provided masks, one per batch, first epoch only.
                    call_kwargs['masks'] = masks[mask_idx % len(masks)]
                    mask_idx += 1
                with tf.GradientTape() as tape:
                    logits = client_model(data, training=True, **call_kwargs)
                    losses = loss_func(label, logits)
                temp_gradients = tape.gradient(losses, client_model.trainable_variables)

                # Rescale gradients by last_batch_length / batch_length outside [0, batch_number).
                # NOTE(review): ``batch`` ranges over [0, batch_number), so with a known
                # cardinality this branch never fires. With an unknown cardinality
                # (batch_number < 0) it fires on every batch. Confirm the intended condition.
                temp_gradients = tf.cond(tf.less(batch, batch_number),
                                         lambda: temp_gradients,
                                         lambda: tf.nest.map_structure(
                                             lambda x: tf.multiply(tf.divide(x, tf.cast(batch_length, tf.float32)),
                                                                   tf.cast(last_batch_length, tf.float32)),
                                             temp_gradients))
                if debug_batch:
                    # assumes the first gradient tensor has rank >= 4 — TODO confirm
                    tf.print("Gradient: ", temp_gradients[0][0][0][0])

                if masks is not None and epoch == -1:
                    # Disabled accumulate-then-apply path: ``epoch`` is never -1.
                    gradients = tf.nest.map_structure(lambda x, y: tf.add(x, y), temp_gradients, gradients)
                elif md:
                    # Multiplicative mirror-descent step instead of the optimizer.
                    for var, grad in zip(client_model.trainable_variables, temp_gradients):
                        var.assign(apply_mirror_descent_update(var, grad, optimizer.learning_rate))
                else:
                    optimizer.apply_gradients(zip(temp_gradients, client_model.trainable_variables))
                    if kl_limit is not None and compressor.compute_model_kls(tmp_model, client_model) > kl_limit:
                        # Replace the model with the compressor's KL-ball projection and
                        # end this epoch early.
                        new_posterior = compressor.project_model_onto_kl_ball(tmp_model, client_model, epsilon=1000)
                        compressor.aggregate_and_update([new_posterior], client_model, reset=True)
                        break

            if masks is not None and epoch == -1:
                # Disabled (epoch is never -1): apply the accumulated gradients once.
                optimizer.apply_gradients(zip(gradients, client_model.trainable_variables))
                if verbose:
                    print("Batch", batch)
                    tf.print("Total gradients: ", gradients[0][0][0][0])

        return gradients

    _, _, _, last_batch_length = get_batched_dataset_specs(training_dataset)
    # batch_length=None: local_update reads the full batch size from the first batch.
    # The specs helper does not return it, and the batch count is not a substitute.
    local_update(client_model, training_dataset, loss_func, optimizer, loss_metric, accuracy_func,
                 local_epoch_number, None, last_batch_length, masks=masks)  #* Theta
    return 0, 0, 0

def client_update_classical(client_model, training_dataset, loss_func, optimizer, loss_metric, accuracy_func, local_epoch_number, masks=None, client_idx=-1, verbose=False, tmp_model=None, compressor=None, config=None):
    """Run plain local training on ``client_model`` and return the model delta.

    For each batch, gradients of ``loss_func(label, logits)`` are applied with
    ``optimizer``. When ``config.compressor.project_kl_divergences`` is set and
    ``compressor.compute_model_kls(tmp_model, client_model)`` exceeds it, the
    model is updated from ``compressor.project_model_onto_kl_ball`` and the rest
    of that epoch is skipped.

    Args:
        client_model: model called as ``client_model(data, training=True)``.
        training_dataset: batched dataset of ``(data, label)`` pairs.
        loss_func: callable ``loss_func(label, logits)`` returning the loss.
        optimizer: provides ``apply_gradients``.
        loss_metric, accuracy_func, masks, client_idx: accepted for interface
            compatibility; unused.
        local_epoch_number: number of passes over ``training_dataset``.
        verbose: print debug information for batches 0 and 45.
        tmp_model: starting model for the delta and reference for the KL check.
            When ``None``, the delta is taken against ``client_model``'s
            weights before training.
        compressor: object implementing the KL check and projection.
        config: optional config; ``None`` disables the KL projection.

    Returns:
        ``(delta, 0, 0)``. ``delta`` holds one tensor per entry of
        ``get_weights()``, equal to initial weights minus final weights.
    """
    # @tf.function
    def local_update(client_model, dataset, loss_func, optimizer, loss_metric, accuracy_func, epoch_number,
                     batch_length, last_batch_length, masks=None, client_idx=client_idx, verbose=verbose, md=False, tmp_model=None):
        """Run the epoch/batch loop and return initial-minus-final weights."""
        # Without tmp_model, snapshot the client's own starting weights.
        # get_weights() returns NumPy arrays, presumably detached from the variables — TODO confirm.
        reference_model = tmp_model if tmp_model is not None else client_model
        tmp_weights = reference_model.get_weights()
        batch_number = tf.data.experimental.cardinality(dataset)
        # A missing config (the default) disables the KL projection instead of raising.
        kl_limit = None if config is None else config.compressor.project_kl_divergences

        for _ in range(int(epoch_number)):
            for batch, (data, label) in dataset.enumerate(start=0):
                if batch_length is None:
                    # Full batch size, read from the first batch (only the last can be short).
                    batch_length = tf.shape(data)[0]
                debug_batch = verbose and (batch < 1 or batch == 45)
                if debug_batch:
                    print(batch)

                with tf.GradientTape() as tape:
                    logits = client_model(data, training=True)
                    losses = loss_func(label, logits)
                temp_gradients = tape.gradient(losses, client_model.trainable_variables)

                # Rescale gradients by last_batch_length / batch_length outside [0, batch_number).
                # NOTE(review): with a known cardinality this branch never fires, because
                # ``batch`` < batch_number always. Confirm the intended condition.
                temp_gradients = tf.cond(tf.less(batch, batch_number),
                                         lambda: temp_gradients,
                                         lambda: tf.nest.map_structure(
                                             lambda x: tf.multiply(tf.divide(x, tf.cast(batch_length, tf.float32)),
                                                                   tf.cast(last_batch_length, tf.float32)),
                                             temp_gradients))
                if debug_batch:
                    # assumes the first gradient tensor has rank >= 4 — TODO confirm
                    tf.print("Gradient: ", temp_gradients[0][0][0][0])

                optimizer.apply_gradients(zip(temp_gradients, client_model.trainable_variables))
                if kl_limit is not None and compressor.compute_model_kls(tmp_model, client_model) > kl_limit:
                    # Replace the model with the compressor's KL-ball projection and
                    # end this epoch early.
                    new_posterior = compressor.project_model_onto_kl_ball(tmp_model, client_model, epsilon=1000)
                    compressor.aggregate_and_update([new_posterior], client_model, reset=True)
                    break

        # Overall "gradient": difference between initial and final model weights.
        return [tf.subtract(x, y) for x, y in zip(tmp_weights, client_model.get_weights())]

    _, _, _, last_batch_length = get_batched_dataset_specs(training_dataset)
    # batch_length=None: local_update reads the full batch size from the first batch.
    gradients = local_update(client_model, training_dataset, loss_func, optimizer, loss_metric, accuracy_func,
                             local_epoch_number, None, last_batch_length, masks=masks, tmp_model=tmp_model)  #* Theta
    return gradients, 0, 0
