import os, sys, importlib
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'
import tensorflow as tf
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jr
import numpy as np
import pyarrow as pa 
import pandas as pd
import matplotlib.pyplot as plt 
import matplotlib as mpl
from scipy.linalg import circulant
import logging
import warnings
from comp import FedPMRECCompressor, PermKCompressor
from mask_models import LeNet5Masked, ResNet18Masked
# Silence the "v2.11+ Adam optimizer runs slowly on M1/M2 Macs" notice by
# message. The pattern is a raw string so the regex escape `\+` is not parsed
# as an (invalid) Python string escape, which emits SyntaxWarning on Python
# 3.12+ (DeprecationWarning before) and is a SyntaxError under -W error.
warnings.filterwarnings('ignore', message=r'.*the v2.11\+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs.*')
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# NOTE(review): presumably a workaround for duplicate OpenMP runtime libraries
# being loaded in one process — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# NOTE(review): TensorFlow was already imported above with this variable set
# to '2'; whether this later value takes effect depends on when TF reads it.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '4'
# Blanket "ignore every warning" filter; it is inserted ahead of (and thus
# supersedes) the targeted filter above.
warnings.filterwarnings('ignore')
np.set_printoptions(precision=6, suppress=True)
import copy
import ml_collections
import importlib
from keras import backend as K  # If using keras

def clear_memory():
    """Reset Keras and TensorFlow global state.

    Clears the Keras backend session first, then resets TensorFlow's
    ``compat.v1`` default graph (only relevant to TF1-style graph code).
    """
    # Keras backend session reset.
    K.clear_session()
    # TF1 compatibility-mode default graph reset.
    tf.compat.v1.reset_default_graph()

from Utils import *
from Training import *
import wandb
import argparse
import multiprocessing

# Early-stopping settings. Neither name is read in the visible part of this
# file. NOTE(review): presumably EARLY_STOPPING toggles early stopping and
# EARLY_STOPPING_RATE is the minimum improvement threshold — confirm against
# the code that consumes them.
EARLY_STOPPING: bool = False
EARLY_STOPPING_RATE: float = 0.001

def inverse_sigmoid(x):
    """Return the logit of ``x`` (the inverse of ``tf.sigmoid``).

    Computed element-wise as ``log(x) - log(1 - x)`` with no clipping, so an
    input of exactly 0 yields ``-inf``, exactly 1 yields ``+inf``, and values
    outside ``[0, 1]`` yield NaN.

    Args:
        x: Tensor (or tensor-convertible value) of probabilities.

    Returns:
        Tensor of logits, element-wise over ``x``.
    """
    log_p = tf.math.log(x)
    log_complement = tf.math.log(1 - x)
    return log_p - log_complement

def log_histograms(epoch, prior, posterior_update, suffix, run):
    """Log histograms of ``sigmoid(prior)`` and ``sigmoid(posterior_update)``.

    Both inputs are passed through ``tf.sigmoid`` before histogramming. One
    ``run.log`` call is made with the keys ``'epoch'``, ``prior_<suffix>``
    and ``posteriors_<suffix>``.

    Args:
        epoch: Value recorded under the ``'epoch'`` key.
        prior: Tensor of prior logits.
        posterior_update: Tensor of posterior logits.
        suffix: Appended to the metric names (e.g. ``"uplink"`` / ``"downlink"``).
        run: Run object returned by ``wandb.init``.
    """
    prior_probs = tf.sigmoid(prior).numpy()
    posterior_probs = tf.sigmoid(posterior_update).numpy()
    payload = {
        'epoch': epoch,
        f"prior_{suffix}": wandb.Histogram(prior_probs),
        f"posteriors_{suffix}": wandb.Histogram(posterior_probs),
    }
    run.log(payload)

def process_client(arg, central=False):
    """Compress one client's model against a reference model on the CPU.

    Used as the ``multiprocessing.Pool.map`` worker (``central`` left False)
    and for the sequential in-process path (``central=True``).

    Args:
        arg: 7-tuple ``(client, compressor, server_iteration, config,
            posterior_ref, target, prior_ref)``. In ``main`` the last three
            entries are flat tensors produced by ``trainables_tensor``; the
            last one is built from the client's prior.
        central: When False, TensorFlow is told to see no GPUs in this process.

    Returns:
        Whatever ``compressor.process`` returns. ``main`` unpacks it as an
        8-tuple (samples, KLs, prior, posterior, ids, block KLs, block sizes,
        new ids).
    """
    # Hide CUDA devices. NOTE: this also runs when central=True, so the
    # calling process's environment is modified as well.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    import keras  # noqa: F401 -- not referenced below; import kept as before
    import tensorflow as tf
    if not central:
        tf.config.set_visible_devices([], 'GPU')
    with tf.device('/CPU:0'):
        (client, compressor, server_iteration, config,
         posterior_ref, target, prior_ref) = arg
        compressor.global_epoch = server_iteration
        comp_cfg = config.compressor
        # The reference is the client's previous noisy posterior or its prior.
        reference = posterior_ref if comp_cfg.use_posterior_prior else prior_ref
        return compressor.process(
            reference,
            target,
            client_id=client,
            project_blocks=comp_cfg.project_block_kl_divergences,
            tf_models_provided=False,
        )

def trainables_tensor(model):
    """Flatten all trainable variables of ``model`` into one rank-1 tensor.

    Each variable is reshaped to 1-D and the pieces are concatenated in
    ``model.trainable_variables`` order.

    NOTE(review): a model with no trainable variables hands ``tf.concat`` an
    empty list, which presumably raises — confirm if that case can occur.
    """
    flat_parts = [tf.reshape(variable, [-1]) for variable in model.trainable_variables]
    return tf.concat(flat_parts, axis=0)

def main(config: ml_collections.ConfigDict, log_df: pd.DataFrame, procs=None):

    gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Memory growth enabled for {len(gpus)} GPU(s).")
    except RuntimeError as e:
        print(e)

    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method('spawn')
    PERFECT_DOWNLINK = config.compressor.downlink_samples == 0
    PERFECT_UPLINK = config.compressor.uplink_samples == 0

    a = tf.random.uniform(shape=(1,len([1, 2, 3])), minval=0, maxval=10000000, dtype=tf.int32).numpy().tolist()
    print(a)

    physical_devices = tf.config.list_physical_devices()
    print(f"The physical devices are {physical_devices}")

    RNG = np.random.default_rng(config.seed)
    JKEY = jax.random.PRNGKey(config.seed) #! Should be parameterized

    os.makedirs('logs/', exist_ok=True)
    if config.beta.mode == 'index':
        beta_samples = np.logspace(start=-1, stop=np.log10(2), num=15, base=10)  # from 0.1 to 2
        beta = beta_samples[config.beta.index]
    elif config.beta.mode == 'value':
        beta = config.beta.value


    (x_train, y_train), (x_test, y_test), config.data.shape = load_and_preprocess_dataset(config.data.name, JKEY)

    def copy_model(model, compile=False):
        # tf.keras.backend.clear_session()
        model_copy = config.train.model()
        model_copy(tf.random.normal([2, *config.data.shape]))
        model_copy.set_weights(model)
        return model_copy

    test_dataset = batch_dataset(x_test, y_test, config.data.batch_size)

    print(f"The training data has a shape of {x_train.shape} and the test data has a shape of {x_test.shape}")
    print(f"The training labels has a shape of {y_train.shape} and the test labels has a shape of {y_test.shape}")

    local_epoch_distribution = generate_local_epoch_distribution(config.worker.num, RNG, config.worker.epoch.type, config.worker.epoch.is_random, config.server.num_epochs,
                                                                        config.worker.epoch.mean, config.worker.epoch.std, config.worker.epoch.beta, config.worker.epoch.coef)
    print(f"The local epoch distribution is {local_epoch_distribution.shape} shaped")
    print(f"The local epoch distribution is {local_epoch_distribution}")
    seperated_index_by_label, seperated_data_by_label = split_data_by_labels(x_train, y_train)
    validation_index, validation_data = sample_data_per_label(config.data.num_validation, RNG, x_train, seperated_index_by_label)
    client_labels, client_data = allocate_client_datasets(config.worker.num, RNG, config.data.alloc_type, config.data.alloc_ratio, seperated_data_by_label, config.data.beta, config.data.shape)
    # label_number = len(seperated_data_by_label)

    #*##########################################################
    if config.worker.num <= 20:
        plt.figure(1, figsize=(int(8 * config.worker.num / 10), 8))
    else:
        plt.figure(1, figsize=(int(4 * config.worker.num / 10), 16))
    for client in range(len(client_labels)):
        if config.worker.num <= 20:
            plt.subplot(5, int(np.ceil(config.worker.num / 5)), client + 1)
        else:
            plt.subplot(10, int(np.ceil(config.worker.num / 10)), client + 1)
        plt.hist(client_labels[client], color="lightblue", ec="red", align="left", bins=np.arange(11))
        plt.title("Client " + str(client + 1))
    plt.suptitle("Label Distributions of Clients")
    plt.tight_layout()
    plt.savefig('logs/'+'Client_Histogram.png')
    #*##########################################################

    active_client_matrix = generate_active_client_matrix(config.worker.inact_prob, RNG, config.server.num_epochs, config.worker.num)
    client_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=config.worker.tx.lr)
    federator_optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=config.worker.tx.lr)
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
    loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
    accuracy_func = tf.keras.metrics.SparseCategoricalAccuracy()

    inputs = tf.keras.Input(shape=config.data.shape, dtype='float32', name="Input")
    federator_model = config.train.model()
    dummy_input = tf.random.normal([2, *config.data.shape])  # Batch size of 1 for testing
    federator_model(dummy_input)
    print("Total numer of parameters: ", np.sum([tf.keras.backend.count_params(w) for w in federator_model.trainable_weights]))
    log_dir = "logs/profile"
    file_writer = tf.summary.create_file_writer(log_dir)

    # federator_model = model
    federator_weights = federator_model.get_weights()
    previous_gradients = [tf.zeros_like(var) for var in federator_weights]
    gradients = np.array([])
    client_properties = np.zeros((config.server.num_epochs, 2))
    validation_properties = np.zeros((config.server.num_epochs, 2))
    test_properties = np.zeros((config.server.num_epochs, 2))
    wait, patience = 0, 3

    ul_comm = 0
    dl_comm = 0
    comm = 0

    logged_config = flatten_dict(config.to_dict())
    run = wandb.init(project=config.wandb.project, job_type=config.wandb.job_type,
                     name=config.wandb.name, config=logged_config)

    alphas, betas = 1, 1

    num_parameters = sum(np.size(layer) for layer in federator_model.trainable_weights)
    permk_compressor = PermKCompressor(config.worker.num, num_parameters, ef=None, clipping=None)
    client_compressors = dict()
    client_prev_noisy_global = dict()
    client_prev_noisy_posterior = dict()
    client_prev_prev_noisy_posterior = dict()
    client_models = dict()
    client_priors = dict()
    tmp_model = copy_model(federator_weights)

    max_test_acc = 0

    mask_samples = -1 # To indicate the first iteration

    compressor = FedPMRECCompressor(use_indiv_reference=False, num_indices=config.compressor.uplink_samples, no_compress=PERFECT_UPLINK, adaptive=config.compressor.adaptive_blocks_ul, kl_rate=config.compressor.kl_rate_ul,
                                    adaptive_avg=config.compressor.adaptive_avg, use_indices_immediately=(not config.compressor.reuse_samples), num_samples=config.compressor.sample_size, block_size=config.compressor.block_size, max_block_size=config.compressor.max_block_size)
    compressor.init(federator_model)

    if config.compressor.reuse_samples:
        compressor.update_blocks = False

    client_momentums = dict()
    for client in range(config.worker.num):
        client_prev_noisy_global[client] = copy_model(federator_weights)
        client_priors[client] = copy_model(federator_weights)
        client_prev_noisy_posterior[client] = copy_model(federator_weights)
        client_prev_prev_noisy_posterior[client] = copy_model(federator_weights)
        client_models[client] = copy_model(federator_weights)
        client_momentums[client] = copy_model(federator_weights)
        client_compressors[client] = FedPMRECCompressor(num_indices=config.compressor.downlink_samples, no_compress=PERFECT_DOWNLINK, adaptive=config.compressor.adaptive_blocks_ul, kl_rate=config.compressor.kl_rate_dl, adaptive_avg=config.compressor.adaptive_avg, use_indices_immediately=(not config.compressor.reuse_samples), num_samples=config.compressor.sample_size, block_size=config.compressor.block_size, max_block_size=config.compressor.max_block_size)
        client_compressors[client].init(client_prev_noisy_global[client])
        if config.compressor.reuse_samples:
            client_compressors[client].update_blocks = False

    sample_weights = list()
    for server_iteration in range(config.server.num_epochs):
        log_row = pd.Series()
        print(f'Federated Learning iteration {server_iteration + 1} ...')
        epoch_start_time = tf.timestamp()
        active_client_count = 0
        client_loss = tf.Variable(0, dtype=tf.float32)
        client_accuracy = tf.Variable(0, dtype=tf.float32)
        # aggregated_models = [tf.zeros_like(var) for var in federator_weights]
        federator_model.set_weights(federator_weights)

        avg_block_kl_divergences_downlink = list()
        kl_divergences_uplink = list()
        kl_divergences_downlink = list()
        avg_block_sizes_uplink = list()
        avg_block_sizes_downlink = list()

        aggregated_model = [tf.zeros_like(var) for var in federator_weights]

        # Sample the same global model estimate for each of the clients an update the prior accordingly
        if config.compressor.common_dl_prior and server_iteration != 0:
            client_compressors[0].global_epoch = server_iteration
            if config.compressor.reuse_samples and config.compressor.uplink_samples > 0:
                mask_samples = copy.deepcopy(sample_weights)
            else:
                mask_samples, kls_downlink, prior_downlink, posterior_downlink, ids, block_kls, block_sizes, new_ids = client_compressors[0].process(client_priors[client], federator_model, client_id=client)
            client_samples = copy.deepcopy(mask_samples)
            masks_structures = list()
            for m in mask_samples:
                masks_structures.append(client_compressors[0].structure_mask(m))

            if not config.compressor.reuse_samples:
                avg_block_sizes_downlink.append(np.mean(block_sizes))
                avg_block_kl_divergences_downlink.append(np.mean(block_kls))
            for client in range(config.worker.num):

                if config.train.compute_divs:
                    print("Downlink KLs ", client, " ")
                    print("Noisy posterior: ", compressor.compute_model_kls(client_prev_noisy_posterior[client], federator_model))
                    print("Noisy real posterior: ", compressor.compute_model_kls(client_models[client], federator_model))
                    print("Noisy global: ", compressor.compute_model_kls(client_prev_noisy_global[client], federator_model))

                if not config.compressor.reuse_samples:
                    # Transfer the global model through importance sampling using the previously transmitted global model as referenc
                    log_histograms(server_iteration, prior_downlink, posterior_downlink, "downlink", run)
                    kl_divergences_downlink.append(kls_downlink)
                # All client models get assigned the same compressed federator model (stored in client_samples)
                client_compressors[client].aggregate_and_update(client_samples, client_models[client], reset=config.compressor.reset_aggregation)

                # The common global model is set as prior
                client_prev_noisy_global[client].set_weights(client_models[client].get_weights())
                client_priors[client].set_weights(client_prev_noisy_global[client].get_weights())
                if not PERFECT_DOWNLINK:
                    if not config.compressor.reuse_samples:
                        c = config.compressor.downlink_samples * len(block_kls) * np.log2(client_compressors[client].num_samples)
                    else:
                        c = c_ul * (config.worker.num-1)/config.worker.num**2 # only send the indices to all clients but the one from which they originated
                else:
                    c = num_parameters * 32
                dl_comm += c
                comm += c

        avg_block_kl_divergences_uplink = list()
        client_ids_uplink = list()

        if server_iteration != 0 and not config.compressor.common_dl_prior:
            # Downlink Transmission
            arguments = []
            with tf.device('/CPU:0'):
                for client in range(config.worker.num):
                    if active_client_matrix[server_iteration][client]:
                        arguments.append((client, client_compressors[client], server_iteration, config, trainables_tensor(client_prev_noisy_posterior[client]), trainables_tensor(federator_model), trainables_tensor(client_priors[client])))

                num_procs = int(procs) if procs is not None else config.worker.num
                if num_procs > 1:
                    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
                    with multiprocessing.Pool(processes=num_procs, maxtasksperchild=1) as pool:
                        results = pool.map(process_client, arguments)
                else:
                    results = list()
                    for arg in arguments:
                        results.append(process_client(arg, True))
                    # Transfer the updated local model back to the federator

        for client in range(config.worker.num):
            if active_client_matrix[server_iteration][client]:
                print(f'Training on Client {client} ...')
                active_client_count += 1
                data, label = client_data[client], client_labels[client]
                training_dataset = batch_dataset(data, label, config.data.batch_size)

                if server_iteration != 0 and not config.compressor.common_dl_prior: # In the beginning we assume to have the global model exactly due to shared randomness
                    # Transfer the global model through importance sampling using the previously transmitted global model as reference
                    client_compressors[client].global_epoch = server_iteration

                    if config.train.compute_divs:
                        ptf = compressor.compute_model_kls(client_prev_noisy_posterior[client], federator_model)
                        gtf = compressor.compute_model_kls(client_prev_noisy_global[client], federator_model)
                        print("Downlink KLs ", client, " ")
                        print("Noisy posterior: ", ptf)
                        print("Noisy global: ", gtf)
                        log_row = log(log_row, server_iteration, {f"DivFtGlob{client}": gtf})
                        log_row = log(log_row, server_iteration, {f"DivFtPost{client}": ptf})

                        if server_iteration > 1:
                            tmp = tf.nest.map_structure(lambda x, y, z: inverse_sigmoid(
                                tf.add(tf.multiply(tf.sigmoid(x), 0.1), tf.multiply(
                                    tf.clip_by_value(tf.subtract(tf.sigmoid(y), tf.multiply(tf.sigmoid(z), 0.00)),
                                                     clip_value_min=0.01, clip_value_max=0.99), 0.9))),
                                                        client_prev_noisy_posterior[client].trainable_variables,
                                                        client_prev_noisy_global[client].trainable_variables,
                                                        client_prev_prev_noisy_posterior[client].trainable_variables)
                            # Assign new values to the trainable variables
                            for var, new_value in zip(tmp_model.trainable_variables, tmp):
                                var.assign(new_value)
                            mtf = compressor.compute_model_kls(tmp_model, federator_model)
                            print("Weighted combination: ", mtf)
                            log_row = log(log_row, server_iteration, {f"DivFtMix{client}": compressor.compute_model_kls(tmp_model, federator_model)})

                    mask_samples, kls_downlink, prior_downlink, posterior_downlink, ids, block_kls, block_sizes, new_ids = results[client]
                    client_samples = copy.deepcopy(mask_samples)

                    if config.train.compute_divs: log_histograms(server_iteration, prior_downlink, posterior_downlink, "downlink", run)
                    kl_divergences_downlink.append(kls_downlink)
                    avg_block_sizes_downlink.append(np.mean(block_sizes))
                    avg_block_kl_divergences_downlink.append(np.mean(block_kls))

                    # Update the client models with individual samples from federator model
                    if config.compressor.split_dl:
                        client_compressors[client].aggregate_and_update(client_samples, tmp_model, reset=config.compressor.reset_aggregation)
                        sparsified_model, indices, updates = permk_compressor.compress(tmp_model.get_weights(), client_prev_noisy_global[client].get_weights())
                        client_models[client].set_weights(sparsified_model)
                    else:
                        client_compressors[client].aggregate_and_update(client_samples, client_models[client], reset=config.compressor.reset_aggregation)

                    client_prev_noisy_global[client].set_weights(client_models[client].get_weights())

                    if config.train.compute_divs:
                        print("How noisy is the global: ", compressor.compute_model_kls(client_prev_noisy_global[client], federator_model))
                        log_row = log(log_row, server_iteration, {f"DivFtGlob{client}": compressor.compute_model_kls(client_prev_noisy_global[client], federator_model)})
                    if not PERFECT_DOWNLINK:
                        if not config.compressor.split_dl:
                            c = config.compressor.downlink_samples * len(block_kls) * np.log2(client_compressors[client].num_samples)
                        else:
                            c = config.compressor.downlink_samples * len(block_kls) * np.log2(client_compressors[client].num_samples) / config.worker.num
                    else:
                        c = num_parameters * 32
                    dl_comm += c
                    comm += c

                for var in client_optimizer.variables():
                   var.assign(tf.zeros_like(var))

                client_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=config.worker.tx.lr)

                if server_iteration == 0:
                    masks_structures = -1
                else:
                    masks_structures = list()
                    for m in mask_samples:
                        masks_structures.append(client_compressors[0].structure_mask(m))
                reference_model = copy_model(client_models[client].get_weights())
                _, local_loss, local_acc = client_update(client_models[client], training_dataset, loss_func, client_optimizer, loss_metric, accuracy_func,
                                                         local_epoch_distribution[server_iteration, client], masks=None, tmp_model=reference_model, compressor=client_compressors[client], config=config)

                best_optim = 1
                if server_iteration != 1 and hasattr(config.compressor, "optimize_prior") and config.compressor.optimize_prior and server_iteration > 0:
                    min_kl_divergence = np.inf
                    for optim in np.arange(0.8, 1.0001, 0.01):
                        new_prior = tf.nest.map_structure(lambda x, y: inverse_sigmoid(
                            tf.add(tf.multiply(tf.sigmoid(x), optim), tf.multiply(tf.sigmoid(y), 1 - optim))),
                                                          client_prev_noisy_global[client].get_weights(),
                                                          client_prev_noisy_posterior[client].get_weights())
                        # Assign new values to the trainable variables
                        tmp_model.set_weights(new_prior)
                        kl_div = compressor.compute_model_kls(tmp_model, client_models[client])
                        print("Lambda: ", optim, "->", kl_div)
                        if kl_div < min_kl_divergence:
                            best_optim = optim

                if best_optim == 1:
                    client_priors[client].set_weights(client_prev_noisy_global[client].get_weights())
                else:
                    new_prior = tf.nest.map_structure(lambda x, y: inverse_sigmoid(
                        tf.add(tf.multiply(tf.sigmoid(x), best_optim), tf.multiply(tf.sigmoid(y), 1 - best_optim))),
                                                      client_prev_noisy_global[client].get_weights(),
                                                      client_prev_noisy_posterior[client].get_weights())
                    client_priors[client].set_weights(new_prior)

                if config.train.compute_divs:
                    print("Model: ", client_models[client].get_weights()[0][0][0][0])
                    print("Progress to federator: ", compressor.compute_model_kls(federator_model, client_models[client]))
                    log_row = log(log_row, server_iteration, {f"PtF{client}": compressor.compute_model_kls(federator_model, client_models[client])})
                    print("Progress to noisy global: ", compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client]))
                    log_row = log(log_row, server_iteration, {f"PtG{client}": compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client])})

                if config.compressor.project_kl_divergences != None and compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client]) > config.compressor.project_kl_divergences:
                    new_posterior = compressor.project_model_onto_kl_ball(client_prev_noisy_global[client], client_models[client], epsilon=config.compressor.project_kl_divergences)
                    compressor.aggregate_and_update([new_posterior], client_models[client], reset=config.compressor.reset_aggregation)

                if config.train.compute_divs:
                    print("Progress to noisy global after projection: ", compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client]))
                    log_row = log(log_row, server_iteration, {f"PtGProj{client}": compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client])})

        sample_weights = list()
        arguments = []

        with tf.device('/CPU:0'):
            for client in range(config.worker.num):
                if active_client_matrix[server_iteration][client]:

                    if config.train.compute_divs:
                        print("Uplink KLs ", client, " ")
                        if server_iteration != 0:
                            print("Noisy posterior: ", compressor.compute_model_kls(client_prev_noisy_posterior[client],
                                                                                    client_models[client]))
                            log_row = log(log_row, server_iteration, {
                                f"DivCtPost{client}": compressor.compute_model_kls(client_prev_noisy_posterior[client],
                                                                                   client_models[client])})
                        print("Noisy global: ",
                              compressor.compute_model_kls(client_prev_noisy_global[client], client_models[client]))
                        log_row = log(log_row, server_iteration, {
                            f"DivCtGlob{client}": compressor.compute_model_kls(client_prev_noisy_global[client],
                                                                               client_models[client])})

                    arguments.append((client, compressor, server_iteration, config,
                                      trainables_tensor(client_prev_noisy_posterior[client]),
                                      trainables_tensor(client_models[client]),
                                      trainables_tensor(client_priors[client])))

            num_procs = int(procs) if procs is not None else config.worker.num
            if num_procs > 1:
                os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
                with multiprocessing.Pool(processes=num_procs, maxtasksperchild=1) as pool:
                    results = pool.map(process_client, arguments)
            else:
                results = list()
                for arg in arguments:
                    results.append(process_client(arg, True))

        c_ul = 0
        new_client_ids_uplink = list()
        for client in range(config.worker.num):
            if active_client_matrix[server_iteration][client]:
                client_samples, kls_uplink, prior_uplink, posterior_uplink, ids_uplink, block_kls_uplink, block_sizes_uplink, new_ids_uplink = results[client]

                client_prev_prev_noisy_posterior[client].set_weights(client_prev_noisy_posterior[client].get_weights())
                compressor.aggregate_and_update(client_samples, client_prev_noisy_posterior[client], reset=config.compressor.reset_aggregation)

                avg_block_kl_divergences_uplink.append(np.mean(block_kls_uplink))
                avg_block_sizes_uplink.append(np.mean(block_sizes_uplink))
                client_ids_uplink.append(ids_uplink)
                new_client_ids_uplink.append(new_ids_uplink)
                if not PERFECT_UPLINK:
                    c = config.compressor.uplink_samples * len(block_kls_uplink) * np.log2(compressor.num_samples)
                    c_ul += c
                    if config.compressor.adaptive_blocks_ul and compressor.update_blocks:
                        if not config.compressor.adaptive_avg:
                            c += len(block_sizes_uplink) * np.log2(compressor.max_block_size)
                        else:
                            c += np.log2(compressor.max_block_size)
                else:
                    c = num_parameters * 32
                ul_comm += c
                comm += c

                if config.train.compute_divs: log_histograms(server_iteration, prior_uplink, posterior_uplink, "uplink", run)
                kl_divergences_uplink.append(kls_uplink)
                sample_weights.extend(client_samples)
            else:
                print(f'Client {client} is inactive')
                pass

        avg_kls_ul = np.mean(avg_block_kl_divergences_uplink)
        std_kls_ul = np.std(avg_block_kl_divergences_uplink)

        avg_kls_dl = np.mean(avg_block_kl_divergences_downlink)
        std_kls_dl = np.std(avg_block_kl_divergences_downlink)

        avg_params_ul = np.mean(avg_block_sizes_uplink)
        avg_params_dl = np.mean(avg_block_kl_divergences_downlink)

        if config.compressor.adaptive_blocks_ul:
            if compressor.update_blocks:
                compressor.ids = compressor.aggregate_ids(new_client_ids_uplink, balance=config.compressor.adaptive_avg)
                if not config.compressor.adaptive_avg:
                    dl_comm += config.worker.num * len(compressor.ids) * np.log2(compressor.max_block_size)
                    comm += config.worker.num * len(compressor.ids) * np.log2(compressor.max_block_size)
                else:
                    dl_comm += config.worker.num * np.log2(compressor.max_block_size)
                    comm += config.worker.num * np.log2(compressor.max_block_size)
                compressor.update_blocks = False
                if config.compressor.adaptive_blocks_dl:
                    for client in range(config.worker.num):
                        client_compressors[client].ids = compressor.ids
                        client_compressors[client].update_blocks = False
            else:
                if avg_kls_ul > config.compressor.avg_dev_factor * compressor.kl_rate or avg_kls_ul < 1/config.compressor.avg_dev_factor * compressor.kl_rate:
                    compressor.update_blocks = True

        client_loss.assign(tf.divide(client_loss, active_client_count))
        client_accuracy.assign(tf.divide(client_accuracy, active_client_count))
        client_properties[server_iteration] = np.array([client_loss.numpy(), client_accuracy.numpy()])
        
        if (active_client_count >= (config.worker.num * config.server.update_thresh)):
            compressor.aggregate_and_update(sample_weights, federator_model, reset=config.compressor.reset_aggregation)
            if config.train.compute_divs:
                print("Federator model: ", federator_model.get_weights()[0][0][0][0])
                print("Federator progress to noisy global: ",
                      compressor.compute_model_kls(client_prev_noisy_global[client], federator_model))
            federator_weights = federator_model.get_weights()
            
            for class_label in validation_index.keys():
                class_validation_data = validation_data[class_label]
                class_validation_label = np.full(validation_index[class_label].size, class_label)
                
                logits = federator_model(class_validation_data, training=False)
                accuracy_func.update_state(class_validation_label, logits)
                loss_metric.update_state(class_validation_label, logits)

            validation_loss, validation_accuracy = loss_metric.result(), accuracy_func.result()
            validation_properties[server_iteration] = np.array([validation_loss.numpy(), validation_accuracy.numpy()])
            
            accuracy_func.reset_state()
            loss_metric.reset_state()
            #*##########################################################

            for batch, (data, label) in test_dataset.enumerate(start = 0):
                step_test(federator_model, data, label, loss_metric, accuracy_func)
            test_loss, test_acc = loss_metric.result(), accuracy_func.result()
            accuracy_func.reset_state()
            loss_metric.reset_state()

            test_properties[server_iteration] = np.array(test_loss.numpy(), test_acc.numpy())
            accuracy_func.reset_state()
            loss_metric.reset_state()
            
            print(f'Validation Loss and Validation Accuracy for iteration {server_iteration + 1} are {test_loss.numpy()} and {test_acc.numpy()}')
            print(f"Average Client Loss and Client Accuracy for iteration {server_iteration + 1} are {client_loss.numpy()} and {client_accuracy.numpy()}")


            log_row = log(log_row, server_iteration,{
                'epoch': server_iteration,
                'avg_val_loss_clients': client_loss.numpy(),
                'avg_val_acc_clients': client_accuracy.numpy(),
                'val_loss': validation_loss.numpy(),
                'val_accuracy': validation_accuracy.numpy(),
                'test_loss': test_loss.numpy(),
                'test_accuracy': test_acc.numpy(),
                'dl_comm': dl_comm,
                'ul_comm': ul_comm,
                'comm': comm,
                'avg_params_dl': avg_params_dl,
                'avg_params_ul': avg_params_ul,
                'dl_bitrate': dl_comm / num_parameters / config.worker.num / (server_iteration+1),
                'ul_bitrate': ul_comm / num_parameters / config.worker.num / (server_iteration+1)
            })


            if test_acc.numpy() > max_test_acc:
                max_test_acc = test_acc.numpy()

            print(avg_kls_dl, avg_kls_ul, avg_params_ul, avg_params_dl)

            log_row = log(log_row, server_iteration,{"Client {} KL Uplink".format(idx): kl for idx, kl in enumerate(kl_divergences_uplink)})
            log_row = log(log_row, server_iteration,{"Client {} KL Downlink".format(idx): kl for idx, kl in enumerate(kl_divergences_downlink)})
            log_row = log(log_row, server_iteration,{
                "Avg KL Uplink:": np.mean(kl_divergences_uplink),
                "Avg KL Downlink:": np.mean(kl_divergences_downlink),
                "Avg Block KL Downlink:": avg_kls_dl,
                "Std Block KL Downlink:": std_kls_dl,
                "Avg Block KL Uplink:": avg_kls_ul,
                "Std Block KL Uplink:": std_kls_ul})

            if EARLY_STOPPING and server_iteration:
                if test_properties[server_iteration - 1, 0] - test_properties[server_iteration, 0] < EARLY_STOPPING_RATE:
                    wait += 1
                else:
                    wait = 0
                if wait >= patience:
                    break
        else:
            print(f"Server Iteration {server_iteration + 1} is neglected due to insufficient active clients.")
        
        epoch_end_time = tf.timestamp()
        print(f'Federated Learning iteration {server_iteration + 1} is completed in {epoch_end_time - epoch_start_time} seconds.')

        if log_df.empty:
            log_df = log_row.to_frame().T
        else:
            log_df.loc[len(log_df)] = log_row

        if server_iteration == config.server.num_epochs-1:
            log_row = log(log_row, server_iteration, {
                "Best Test Accuracy:": max_test_acc,
                "Average Bitrate": comm / num_parameters / config.worker.num / config.server.num_epochs,
                "Average Bitrate UL": ul_comm / num_parameters / config.worker.num / config.server.num_epochs,
                "Average Bitrate DL": dl_comm / num_parameters / config.worker.num / config.server.num_epochs})

    x_range = np.arange(1, gradients.size + 1)
    plt.figure(2)
    plt.plot(x_range, gradients)
    plt.title('Gradients vs Epoch')
    plt.xlabel('Epoch Number')
    plt.ylabel('Gradient Norm')
    plt.grid(visible=True)
    plt.savefig('logs/'+f'Grad vs Epoch.png')

    log_df['Run'] = config.wandb.name
    wandb.config.update(config.to_dict(), allow_val_change=True)
    run.finish()

    return federator_model, log_df

def parse_arguments():
    """Parse the command-line options of the experiment driver.

    Options (as consumed by the ``__main__`` block of this file):
        --dir:     forwarded as ``dir_path`` to ``load_and_extract_configs``
                   (default ``"Params"``).
        --include: zero or more values forwarded as ``include`` to
                   ``load_and_extract_configs`` (default ``None``).
        --exclude: zero or more values forwarded as ``exclude`` to
                   ``load_and_extract_configs`` (default ``None``).
        --procs:   raw string forwarded unchanged to ``main`` (default ``None``).
        --seed:    integer seed. ``-1`` (the default) is a sentinel meaning
                   "seed each run with its run index".

    Returns:
        argparse.Namespace with attributes ``dir``, ``include``, ``exclude``,
        ``procs`` and ``seed``. ``seed`` is always an ``int``, so the ``-1``
        sentinel compares equal whether it came from the default or from an
        explicit ``--seed -1``.
    """
    parser = argparse.ArgumentParser(description=".")
    parser.add_argument('--dir', type=str, help='', default="Params")
    parser.add_argument('--include', nargs='+', default=None, help='')
    parser.add_argument('--exclude', nargs='+', default=None, help='')
    parser.add_argument('--procs', type=str, default=None, help='')
    # type=int: with type=str an explicit "--seed -1" became the string "-1",
    # which never matched the integer -1 sentinel checked by the caller.
    parser.add_argument('--seed', type=int, default=-1, help='')
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: load every selected config, run it
    # `config.train.num_runs` times, and write each run's log DataFrame to
    # logs/<run_name>.csv. A failing run is reported and skipped so the
    # remaining runs/configs still execute.
    import traceback

    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    args = parse_arguments()
    configs = load_and_extract_configs(dir_path=args.dir, include=args.include, exclude=args.exclude)

    for idx, config in enumerate(configs):
        print(f"Running with Config #{idx + 1}...")
        clear_memory()
        for run in range(int(config.train.num_runs)):
            try:
                # A seed of -1 is the sentinel for "seed each run with its run
                # index". int() normalizes the value so the comparison also
                # holds if the parser hands back a string.
                requested_seed = int(args.seed)
                config.seed = run if requested_seed == -1 else requested_seed
                print(f"Running with seed {config.seed}")
                run_name = get_run_name(config, run + 1)
                config.wandb.name = run_name
                log_df = pd.DataFrame()
                print(f"Run #{run + 1} with Config #{idx + 1} started...")
                federator_start_time = datetime.now()
                federator_model, log_df = main(config, log_df, args.procs)
                federator_end_time = datetime.now()
                print(
                    f'The federated learning process is completed in {(federator_end_time - federator_start_time).total_seconds()} seconds.')
                # The path goes straight to the filesystem API, so spaces need
                # no shell escaping.
                file_path = os.path.join('logs', run_name) + ".csv"
                directory = os.path.dirname(file_path)
                if directory:
                    os.makedirs(directory, exist_ok=True)
                log_df.to_csv(file_path, index=False)
            except Exception as inst:
                # Best-effort: report the failure (including the traceback, which
                # was previously lost) and move on to the next run.
                print(type(inst))    # the exception type
                print(inst.args)     # arguments stored in .args
                print(inst)
                traceback.print_exc()
                print("Run not successful!")