import random
import numpy as np
import tensorflow as tf
from scipy.stats import bernoulli
import wandb
from scipy.optimize import minimize
import cvxpy as cp
import time

class Compressor:
    """Base class for gradient compressors.

    ``process`` sends a gradient through ``compress`` and then ``decompress``.
    In this base class ``compress`` is a placeholder returning ``None`` and
    ``decompress`` is the identity; subclasses override them.
    """

    def __init__(self, ef, clipping):
        # ``ef`` scales the error-feedback residual in FlatCompressor
        # subclasses; ``clipping`` is stored but not read by the visible code.
        self.ef, self.clipping = ef, clipping

    def process(self, gradient):
        """Compress ``gradient`` and immediately decompress the result."""
        compressed = self.compress(gradient)
        return self.decompress(compressed)

    def compress(self, gradient):
        """Placeholder: subclasses return a compressed representation."""
        return None

    def decompress(self, compressed_gradient):
        """Identity by default."""
        return compressed_gradient

class NoCompress(Compressor):
    """Identity compressor: the gradient is passed through unchanged."""

    def __init__(self):
        # No error feedback and no clipping; the base __init__ is not called.
        self.ef, self.clipping = 0, 0

    def process(self, gradient, model_old, model_new):
        """Return ``gradient`` unchanged; the model arguments are ignored."""
        compressed = self.compress(gradient)
        return self.decompress(compressed)

    def compress(self, gradient):
        """Identity."""
        return gradient

    def decompress(self, compressed_gradient):
        """Identity."""
        return compressed_gradient

class GradientProjection(Compressor):
    """Estimate a gradient from its projections onto seeded random directions.

    For each seed the global TensorFlow RNG is re-seeded, a Gaussian
    direction ``z`` is drawn and normalised to unit length, and the flattened
    gradient is projected onto it. The estimate is the mean over seeds of
    ``scaling * <g, z> * z``, split back into the per-variable shapes.
    """

    def __init__(self, seeds, scaling):
        # Expose the base-class attributes like the other compressors do;
        # projection itself uses neither.
        super().__init__(ef=None, clipping=None)
        self.seeds = seeds
        self.scaling = scaling

    def process(self, gradient):
        """Return the decompressed gradient estimate (first item of ``compress``)."""
        return self.decompress(self.compress(gradient)[0])

    def compress(self, gradient):
        """Project ``gradient`` onto the seeded directions and reconstruct it.

        Args:
            gradient: Sequence of gradient tensors, one per variable.

        Returns:
            ``(reshaped_gradients, None)`` where ``reshaped_gradients`` has the
            same per-variable shapes as ``gradient``.

        Raises:
            ValueError: if ``self.seeds`` is empty (the mean over seeds would
                divide by zero and silently produce NaNs).
        """
        seeds = list(self.seeds)
        if not seeds:
            raise ValueError("GradientProjection needs at least one seed")

        flat_gradients = tf.concat([tf.reshape(grad, [-1]) for grad in gradient], axis=0)
        gradient_estimates = tf.zeros_like(flat_gradients)

        for seed in seeds:
            # Re-seeding presumably lets a receiver regenerate the same
            # direction from the seed alone.
            tf.random.set_seed(seed)
            # Draw in the gradient's dtype so the dot product also works for
            # non-float32 gradients (float32 draws are unchanged).
            z = tf.random.normal(flat_gradients.shape, dtype=flat_gradients.dtype)
            z = z / tf.norm(z)
            proj = tf.tensordot(flat_gradients, z, axes=1) * self.scaling
            gradient_estimates += proj * z

        gradient_estimates /= len(seeds)

        original_shapes = [g.shape for g in gradient]
        sizes = [tf.reduce_prod(shape).numpy() for shape in original_shapes]
        splits = tf.split(gradient_estimates, sizes)
        reshaped_gradients = [tf.reshape(split, shape) for split, shape in zip(splits, original_shapes)]
        return reshaped_gradients, None


@tf.function
def flatten_grad(grad):
    """Concatenate every tensor in ``grad`` into a single 1-D float32 tensor.

    Each element is reshaped to 1-D and appended, in order, to a dynamically
    sized float32 ``TensorArray``; the array is then concatenated.
    """
    flat_parts = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
    for tensor in grad:
        # Append at the current end of the array.
        flat_parts = flat_parts.write(flat_parts.size(), tf.reshape(tensor, [-1]))
    return flat_parts.concat()


class FlatCompressor(Compressor):
    """Base class for compressors that act on one flattened gradient vector.

    ``compress`` flattens the per-variable gradients, adds the scaled
    error-feedback residual (when ``ef`` is set), calls the subclass's
    ``flatcompress``, optionally rescales the result to the original norm,
    and splits it back into the original shapes.
    """

    def __init__(self, rescale=False, **kwargs):
        super().__init__(**kwargs)
        self.rescale = rescale
        # Residual (input - compressed output) from the previous call. It must
        # exist before the first ``compress``, which reads it.
        self.accumulated_error = None

    def flatcompress(self, flat_gradients):
        """Compress a 1-D gradient vector; subclasses must override this."""
        raise NotImplementedError

    def compress(self, gradient):
        """Compress a list of gradient tensors.

        Args:
            gradient: Sequence of gradient tensors, one per variable.

        Returns:
            List of compressed tensors with the same shapes as ``gradient``.
        """
        flat_gradients = flatten_grad(gradient)

        if self.rescale:
            # Norm of the raw gradient, measured before error feedback.
            original_norm = tf.norm(flat_gradients)

        if self.ef is not None and self.accumulated_error is not None:
            flat_gradients += self.ef * self.accumulated_error

        compressed_grads = self.flatcompress(flat_gradients)

        if self.rescale:
            # divide_no_nan yields 0 instead of NaN when the compressed vector
            # is all zeros (e.g. an all-zero gradient).
            compressed_grads *= tf.math.divide_no_nan(original_norm, tf.norm(compressed_grads))

        if self.ef is not None:
            self.accumulated_error = flat_gradients - compressed_grads

        original_shapes = [g.shape for g in gradient]
        sizes = [tf.reduce_prod(shape) for shape in original_shapes]
        splits = tf.split(compressed_grads, sizes)
        return [tf.reshape(split, shape) for split, shape in zip(splits, original_shapes)]


class TopKCompressor(FlatCompressor):
    """Keep the ``K`` largest-magnitude entries of the gradient; zero the rest."""

    def __init__(self, K, **kwargs):
        super().__init__(**kwargs)
        self.K = K

    @tf.function
    def flatcompress(self, flat_gradients):
        """Return a vector that equals ``flat_gradients`` on its top-K |entries| and is 0 elsewhere.

        If ``K`` exceeds the vector length, every entry is kept instead of
        letting ``top_k`` fail.
        """
        k = tf.minimum(self.K, tf.size(flat_gradients))
        top = tf.math.top_k(tf.math.abs(flat_gradients), k=k)
        indices = tf.reshape(top.indices, (-1, 1))
        values = tf.gather_nd(flat_gradients, indices)
        return tf.scatter_nd(indices=indices, updates=values, shape=tf.shape(flat_gradients))


class RandKCompressor(FlatCompressor):
    """Keep ``K`` uniformly random entries of the gradient and zero the rest.

    The kept entries are not rescaled.
    """

    def __init__(self, K, **kwargs):
        super().__init__(**kwargs)
        self.K = K

    @tf.function
    def flatcompress(self, flat_gradients):
        """Multiply ``flat_gradients`` by a shuffled mask with exactly ``K`` ones."""
        num_dropped = len(flat_gradients) - self.K
        keep_mask = tf.concat([tf.ones(self.K), tf.zeros(num_dropped)], 0)
        keep_mask = tf.random.shuffle(keep_mask)
        return flat_gradients * keep_mask


class SignSGD(FlatCompressor):
    """Compress the gradient to its element-wise sign."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def flatcompress(self, flat_gradient):
        """Return ``tf.sign`` of the flattened gradient (-1, 0 or +1 per entry)."""
        return tf.sign(flat_gradient)

class PermKCompressor(FlatCompressor):
    """Permutation-based sparsifier.

    A shared random permutation of ``range(d)`` is split into slices of ``K``
    indices. Each call writes the gradient values at the current client's
    slice into a copy of the old model vector, then advances the client
    counter. After a full round a new permutation is requested.

    NOTE(review): ``n`` is accepted but ignored. The client count is pinned to
    1 and ``K`` to ``d`` (the source comments out ``n`` and ``d // n``), so
    every call currently covers all ``d`` coordinates. Confirm this is
    intended.
    """

    def __init__(self, n, d, **kwargs):
        super().__init__(**kwargs)
        self.n = 1  # number of clients (hard-coded; see class note)
        self.d = d  # number of parameters
        self.K = d  # coordinates per client (hard-coded; see class note)
        self.client = tf.Variable(0, trainable=False)
        self.permutation = tf.Variable(tf.random.shuffle(tf.range(self.d)), trainable=False)
        self.new_permutation = tf.Variable(False, trainable=False)
        # Error-feedback residual. ``compress`` reads it before the first
        # assignment, so it must exist from construction.
        self.accumulated_error = None

    @tf.function
    def flatcompress(self, flat_gradients, flat_old_model):
        """Overwrite this client's permutation slice of the old model with gradient values.

        Returns:
            ``(res_flattened, indices, updates)``: the updated flat vector, the
            ``(K, 1)`` indices written, and the values written there.
        """
        if self.new_permutation:
            self.permutation.assign(tf.random.shuffle(tf.range(self.d)))
            self.new_permutation.assign(False)

        # Start from the old model; only this client's coordinates change.
        base = flat_old_model
        indices = tf.gather(self.permutation, tf.range(self.client * self.K, (self.client + 1) * self.K))
        indices = tf.reshape(indices, (-1, 1))
        updates = tf.gather(flat_gradients, indices[:, 0])
        res_flattened = tf.tensor_scatter_nd_update(base, indices, updates)

        # Advance to the next client; after a full round request a new permutation.
        self.client.assign((self.client + 1) % self.n)
        if self.client == 0:
            self.new_permutation.assign(True)

        return res_flattened, indices, updates

    def compress(self, gradient, old_model):
        """Compress ``gradient`` relative to ``old_model``.

        Returns:
            ``(reshaped_gradients, indices, updates)``: the result in the
            original per-variable shapes, plus ``indices`` and ``updates``
            from ``flatcompress``.
        """
        flat_gradients = flatten_grad(gradient)
        flat_old_model = flatten_grad(old_model)

        if self.rescale:
            original_norm = tf.norm(flat_gradients)

        if self.ef is not None and self.accumulated_error is not None:
            flat_gradients += self.ef * self.accumulated_error

        compressed_grads, indices, updates = self.flatcompress(flat_gradients, flat_old_model)

        if self.rescale:
            # Avoid NaN when the compressed vector has zero norm.
            compressed_grads *= tf.math.divide_no_nan(original_norm, tf.norm(compressed_grads))

        if self.ef is not None:
            self.accumulated_error = flat_gradients - compressed_grads

        original_shapes = [g.shape for g in gradient]
        sizes = [tf.reduce_prod(shape) for shape in original_shapes]
        splits = tf.split(compressed_grads, sizes)
        reshaped_gradients = [tf.reshape(split, shape) for split, shape in zip(splits, original_shapes)]

        return reshaped_gradients, indices, updates


def sample_mask(self, mask_probs: dict) -> list:
    """Sample a concrete mask from per-layer mask logits.

    Entries whose key contains ``'mask'`` are treated as logits: they are
    passed through a sigmoid, one Bernoulli draw is taken per element, and
    the 0/1 draws are replaced by ``self.epsilon`` / ``1 - self.epsilon``.
    All other entries are converted to NumPy unchanged. The output preserves
    the dict's iteration order.
    """
    def _draw(logits):
        theta = tf.sigmoid(logits).numpy()
        draws = bernoulli.rvs(theta)
        draws = np.where(draws == 0, self.epsilon, draws)
        return np.where(draws == 1, 1 - self.epsilon, draws)

    return [
        _draw(values) if 'mask' in name else values.numpy()
        for name, values in mask_probs.items()
    ]

def inverse_sigmoid(x):
    """Return logit(x) = log(x) - log(1 - x); ±inf at x = 0 or 1."""
    log_p = tf.math.log(x)
    log_not_p = tf.math.log(1 - x)
    return log_p - log_not_p


class FedPMRECCompressor(Compressor):
    def __init__(self, adaptive=False, kl_rate=1, num_samples=256, block_size=256,
                 max_block_size=512, use_indiv_reference=False, num_indices=1,
                 no_compress=False, adaptive_avg=False, use_indices_immediately=True):
        # Block partitioning / sampling configuration (see compute_indices, rec).
        self.adaptive, self.adaptive_avg = adaptive, adaptive_avg
        self.use_indices_immediately = use_indices_immediately
        self.kl_rate, self.num_samples = kl_rate, num_samples
        self.block_size, self.max_block_size = block_size, max_block_size
        self.update_blocks = True

        # Layer mask metadata, filled lazily by init().
        self.mask_shapes, self.layer_names = [], []
        self.mask_shapes_known = False

        self.global_epoch = 0

        # Beta pseudo-counts used by aggregate_and_update.
        self.alphas, self.betas = None, None

        # Per-client stored posteriors used when use_indiv_reference is set.
        self.reference = dict()
        self.use_indiv_reference = use_indiv_reference

        self.num_indices = num_indices
        self.no_compress = no_compress

        # Block boundaries; old_ids is not read by the visible code.
        self.old_ids, self.ids = None, None

    def log_histograms(self, prior, posterior_update):
        """Log wandb histograms of sigmoid(prior) and sigmoid(posterior_update).

        The logging runs through ``tf.py_function`` so it can also be invoked
        from graph code; its output is discarded and the method returns None.
        """
        def _log_to_wandb(prior_logits, posterior_logits):
            wandb.log({
                'epoch': self.global_epoch,
                "prior": wandb.Histogram(tf.sigmoid(prior_logits).numpy()),
                'posteriors': wandb.Histogram(tf.sigmoid(posterior_logits).numpy()),
            })
            return 0

        tf.py_function(func=_log_to_wandb, inp=[prior, posterior_update], Tout=tf.float32)

    # @tf.function
    def process(self, model_old, model_new, client_id=0, project_blocks=None, tf_models_provided=True):
        """Run ``compress`` and pass its sample list through ``decompress``.

        Returns the 8-tuple from ``compress``, with the first element replaced
        by its decompressed form.
        """
        (compressed, kls, prior, posterior, ids,
         block_kls, block_sizes, new_ids) = self.compress(
            model_old, model_new,
            client_id=client_id,
            project_blocks=project_blocks,
            tf_models_provided=tf_models_provided,
        )
        decompressed = self.decompress(compressed)
        return decompressed, kls, prior, posterior, ids, block_kls, block_sizes, new_ids

    # @tf.function
    def update(self, aggregated_model, model):
        """Assign aggregated mask values to every layer of ``model`` that has a ``mask``.

        ``aggregated_model`` is keyed by layer name for ``mask`` and by
        ``layer.name + '_bias_mask'`` for ``bias_mask``. Layers without a
        ``mask`` attribute are skipped. Always returns 0.
        """
        for layer in model.layers:
            if not hasattr(layer, 'mask'):
                continue
            layer.mask.assign(aggregated_model[layer.name])
            if hasattr(layer, 'bias_mask'):
                layer.bias_mask.assign(aggregated_model[layer.name+'_bias_mask'])
        return 0



    # @tf.function
    def aggregate_and_update(self, sample_list, model, reset=True):
        """Aggregate client mask samples with a Beta posterior and write the result into ``model``.

        ``alphas``/``betas`` start at 1 (reset whenever ``reset`` is truthy or
        they are unset) and are incremented by the per-entry sum of the
        samples and its complement. The Beta mode
        ``(alpha - 1) / (alpha + beta - 2)`` is converted to logits, split
        into per-layer tensors using ``self.mask_shapes`` /
        ``self.layer_names``, and assigned via ``update``.

        Args:
            sample_list: List of equal-length flat samples, presumably the
                0.01/0.99 vectors produced by ``rec`` (confirm with callers).
            model: Model whose mask layers receive the aggregated logits.
            reset: Restart the Beta counts from (1, 1) before adding.

        Returns:
            0
        """
        sample_tensor = tf.stack(sample_list)
        num_clients = sample_tensor.shape[0]
        ones_count = tf.reduce_sum(sample_tensor, axis=0)

        if self.alphas is None or reset:
            self.alphas = 1
            self.betas = 1
        self.alphas += ones_count
        self.betas += num_clients - ones_count

        mode = (self.alphas - 1) / (self.alphas + self.betas - 2)
        # Clamp away from {0, 1} so the logits stay finite. Unclamped ±inf
        # logits would make later KL computations against this prior infinite.
        logits = inverse_sigmoid(self.cut_to_prob(mode))

        aggregated_model = {}
        start_index = 0
        for shape, layer_name in zip(self.mask_shapes, self.layer_names):
            end_index = start_index + tf.reduce_prod(shape)
            aggregated_model[layer_name] = tf.reshape(logits[start_index:end_index], shape)
            start_index = end_index

        self.update(aggregated_model, model)
        return 0

    def structure_mask(self, sample):
        """Binarise a flat sampled mask and split it into per-layer float32 tensors.

        Entries equal (within floating-point tolerance) to 0.01 become 0;
        every other entry becomes 1. The flat result is cut into
        ``self.mask_shapes`` in ``self.layer_names`` order.

        Returns:
            dict mapping layer name to its reshaped 0/1 mask.
        """
        # Compare with a tolerance. A float32 0.01 (as produced by ``rec``) is
        # not exactly equal to the Python float 0.01 once it is a NumPy array,
        # so an exact ``==`` would mark every entry as 1 for NumPy input.
        binary = np.where(np.isclose(sample, 0.01), 0, 1)
        binary = tf.convert_to_tensor(binary, dtype=tf.float32)

        structured = {}
        start_index = 0
        for shape, layer_name in zip(self.mask_shapes, self.layer_names):
            end_index = start_index + tf.reduce_prod(shape)
            structured[layer_name] = tf.reshape(binary[start_index:end_index], shape)
            start_index = end_index
        return structured

    def init(self, model):
        """Record mask shapes/names and default fixed-size block boundaries, once.

        Does nothing if the mask shapes are already known. Otherwise:
          * appends the shape and name of every ``mask`` (name = layer name)
            and ``bias_mask`` (name = layer name + ``'_bias_mask'``);
          * sets ``self.ids`` to ``0, block_size, 2*block_size, ...`` closed
            by the total number of trainable parameters.

        NOTE(review): the boundaries are computed over all trainable
        variables, while the masks extracted elsewhere may be a subset.
        Confirm the lengths agree for your models.
        """
        if self.mask_shapes_known:
            return
        for model_layer in model.layers:
            if hasattr(model_layer, 'mask'):
                self.mask_shapes.append(model_layer.mask.shape)
                self.layer_names.append(model_layer.name)
            if hasattr(model_layer, 'bias_mask'):
                self.mask_shapes.append(model_layer.bias_mask.shape)
                self.layer_names.append(model_layer.name + '_bias_mask')
        self.mask_shapes_known = True

        model_flattened = tf.concat([tf.reshape(var, [-1]) for var in model.trainable_variables], axis=0)
        total = len(model_flattened)
        ids = tf.unstack(tf.range(0, total, self.block_size))
        # Always close the final block at ``total``. Comparing against
        # ``total - 1`` dropped the last element whenever
        # total % block_size == 1. This matches compute_indices.
        if ids[-1] < total:
            ids.append(total)
        self.ids = ids

    def compute_model_kls(self, prior_model, posterior_model, verbose=False):
        """Return the summed element-wise Bernoulli KL between two models' masks.

        ``mask`` and ``bias_mask`` logits of paired layers are concatenated,
        passed through a sigmoid, and fed to ``compute_kls`` with the
        posterior model as the first distribution.

        Args:
            prior_model, posterior_model: Models with layer-aligned masks.
            verbose: Print the first 10 posterior and prior logits.

        Returns:
            Scalar tensor with the total KL.
        """
        # Loop-invariant: initialise mask metadata once, before iterating.
        if not self.mask_shapes_known:
            self.init(prior_model)

        prior_masks_concat = []
        posterior_masks_concat = []
        for prior_layer, posterior_layer in zip(prior_model.layers, posterior_model.layers):
            if hasattr(prior_layer, 'mask'):
                prior_masks_concat.append(tf.reshape(prior_layer.mask, [-1]))
                posterior_masks_concat.append(tf.reshape(posterior_layer.mask, [-1]))
            if hasattr(prior_layer, 'bias_mask'):
                prior_masks_concat.append(tf.reshape(prior_layer.bias_mask, [-1]))
                posterior_masks_concat.append(tf.reshape(posterior_layer.bias_mask, [-1]))

        posterior_update = tf.concat(posterior_masks_concat, axis=0)
        prior = tf.concat(prior_masks_concat, axis=0)

        if verbose:
            print(posterior_update[:10])
            print(prior[:10])

        kls = self.compute_kls(prior=tf.reshape(tf.sigmoid(prior), [-1]),
                               posterior=tf.reshape(tf.sigmoid(posterior_update), [-1]))
        return tf.reduce_sum(kls)

    def project_model_onto_kl_ball(self, prior_model, posterior_model, epsilon, tol=1e-9, max_iter=1000, client_id=0, n_samples=1):
        """Project a model's mask distribution onto a KL ball around the prior model.

        Collects the ``mask``/``bias_mask`` logits of both models, applies a
        sigmoid, and solves with cvxpy/ECOS::

            min_x  sum[kl_div(x, q) + kl_div(1 - x, 1 - q)]          (q = posterior)
            s.t.   sum[kl_div(x, p) + kl_div(1 - x, 1 - p)] <= epsilon  (p = prior)
                   0 <= x <= 1

        Each bracket is the Bernoulli KL of x against the given probability.
        Diagnostic values are printed.

        ``tol``, ``max_iter``, ``client_id`` and ``n_samples`` are accepted for
        API compatibility but are not used.

        Returns:
            float32 tensor of projected probabilities, or the unprojected
            posterior probabilities if the solver raises or returns no value.
        """
        if not self.mask_shapes_known:
            self.init(prior_model)

        prior_masks_concat = []
        posterior_masks_concat = []
        for prior_layer, posterior_layer in zip(prior_model.layers, posterior_model.layers):
            if hasattr(prior_layer, 'mask'):
                prior_masks_concat.append(tf.reshape(prior_layer.mask, [-1]))
                posterior_masks_concat.append(tf.reshape(posterior_layer.mask, [-1]))
            if hasattr(prior_layer, 'bias_mask'):
                prior_masks_concat.append(tf.reshape(prior_layer.bias_mask, [-1]))
                posterior_masks_concat.append(tf.reshape(posterior_layer.bias_mask, [-1]))

        posterior = tf.reshape(tf.sigmoid(tf.concat(posterior_masks_concat, axis=0)), [-1])
        prior = tf.reshape(tf.sigmoid(tf.concat(prior_masks_concat, axis=0)), [-1])

        # cvxpy works on NumPy constants; convert explicitly.
        posterior_np = np.asarray(posterior, dtype=np.float64)
        prior_np = np.asarray(prior, dtype=np.float64)

        new_posterior = cp.Variable(posterior_np.shape[0])
        kls = cp.kl_div(new_posterior, posterior_np) + cp.kl_div(1 - new_posterior, 1 - posterior_np)
        objective = cp.Minimize(cp.sum(kls))
        kls_constraint = cp.kl_div(new_posterior, prior_np) + cp.kl_div(1 - new_posterior, 1 - prior_np)
        constraints = [cp.sum(kls_constraint) <= epsilon, new_posterior >= 0, new_posterior <= 1]
        problem = cp.Problem(objective, constraints)

        # Best effort: fall back to the unprojected posterior on any solver
        # failure, including a "successful" call that leaves no value.
        result = None
        try:
            result = problem.solve(solver=cp.ECOS)
            solved = new_posterior.value is not None
        except Exception:
            solved = False
        if not solved:
            print("Problem not solved, KL was ", kls_constraint.value)
            return tf.cast(posterior, dtype=np.float32)

        print("Individual KL divergences:", kls_constraint.value)
        print("Optimal value (objective):", result)
        print("Optimal new_posterior:", new_posterior.value)
        print("Constraint value:", np.sum(cp.kl_div(new_posterior.value, prior_np).value + cp.kl_div(1 - new_posterior.value, 1 - prior_np).value))

        return tf.cast(new_posterior.value, dtype=np.float32)

    #@tf.function
    def project_onto_kl_ball(self, prior, posterior, epsilon, tol=1e-9, max_iter=1000, client_id=0, n_samples=1):
        """Find the probability vector closest in KL to ``posterior`` inside a KL ball around ``prior``.

        Solves with cvxpy/ECOS::

            min_x  sum[kl_div(x, q) + kl_div(1 - x, 1 - q)]          (q = posterior)
            s.t.   sum[kl_div(x, p) + kl_div(1 - x, 1 - p)] <= epsilon  (p = prior)
                   0 <= x <= 1

        Args:
            prior, posterior: Bernoulli probabilities (flattened here).
            epsilon: Radius of the KL ball.
            tol, max_iter, client_id, n_samples: Accepted for API
                compatibility; not used.

        Returns:
            1-D float32 tensor of projected probabilities, or the flattened
            ``posterior`` if the solver raises or returns no value.
        """
        prior = tf.reshape(prior, [-1])
        posterior = tf.reshape(posterior, [-1])
        # cvxpy works on NumPy constants; convert explicitly.
        prior_np = np.asarray(prior, dtype=np.float64)
        posterior_np = np.asarray(posterior, dtype=np.float64)

        new_posterior = cp.Variable(posterior_np.shape[0])
        kls = cp.kl_div(new_posterior, posterior_np) + cp.kl_div(1 - new_posterior, 1 - posterior_np)
        objective = cp.Minimize(cp.sum(kls))
        kls_constraint = cp.kl_div(new_posterior, prior_np) + cp.kl_div(1 - new_posterior, 1 - prior_np)
        constraints = [cp.sum(kls_constraint) <= epsilon, new_posterior >= 0, new_posterior <= 1]
        problem = cp.Problem(objective, constraints)

        # Best effort: fall back to the unprojected posterior on any solver
        # failure, including a call that leaves no value.
        try:
            problem.solve(solver=cp.ECOS)
            solved = new_posterior.value is not None
        except Exception:
            solved = False
        if not solved:
            print("Problem not solved, KL was ", kls_constraint.value)
            return tf.cast(posterior, dtype=np.float32)

        return tf.cast(new_posterior.value, dtype=np.float32)

    def extract_params(self, prior_model, posterior_model, client_id=0):
        """Return the flat ``(prior, posterior)`` mask logits of two models.

        For each pair of aligned layers, ``mask`` and then ``bias_mask`` are
        flattened and concatenated. With ``use_indiv_reference``, the prior is
        averaged 50/50 with the posterior previously stored for
        ``client_id`` (if any), and the current posterior is stored for the
        next call.
        """
        if not self.mask_shapes_known:
            self.init(prior_model)

        prior_parts, posterior_parts = [], []
        for prior_layer, posterior_layer in zip(prior_model.layers, posterior_model.layers):
            for attr in ('mask', 'bias_mask'):
                if hasattr(prior_layer, attr):
                    prior_parts.append(tf.reshape(getattr(prior_layer, attr), [-1]))
                    posterior_parts.append(tf.reshape(getattr(posterior_layer, attr), [-1]))

        posterior_update = tf.concat(posterior_parts, axis=0)
        prior = tf.concat(prior_parts, axis=0)

        if self.use_indiv_reference:
            if client_id in self.reference:
                prior = self.reference[client_id] * 0.5 + 0.5 * prior
            self.reference[client_id] = tf.identity(posterior_update)

        return prior, posterior_update

    # @tf.function
    def compress(self, prior, posterior_update, client_id=0, n_samples=1, project_blocks=None, tf_models_provided=True):
        """Encode the posterior mask relative to the prior by block-wise sampling.

        Args:
            prior, posterior_update: Models (if ``tf_models_provided``) or
                flat mask logits. Both are mapped through a sigmoid.
            client_id: Forwarded to ``extract_params``.
            n_samples: Accepted for API compatibility; not used.
            project_blocks: Forwarded to ``rec`` (KL-ball radius, or falsy).
            tf_models_provided: Whether to extract logits from models first.

        Returns:
            ``(sample_list, kls, prior, posterior_update, ids, block_kls,
            block_sizes, new_ids)``. ``sample_list`` holds ``num_indices``
            samples from ``rec`` (or the posterior probabilities when
            ``no_compress``), ``kls`` is the total KL, and ``prior`` /
            ``posterior_update`` are probabilities.
        """
        if tf_models_provided:
            prior, posterior_update = self.extract_params(prior, posterior_update, client_id=client_id)

        prior = tf.sigmoid(prior)
        posterior_update = tf.sigmoid(posterior_update)
        flat_prior = tf.reshape(prior, [-1])
        flat_posterior = tf.reshape(posterior_update, [-1])

        if not self.no_compress and self.update_blocks:
            new_ids, kls = self.compute_indices(prior=flat_prior, posterior=flat_posterior, old_ids=None)
            if self.use_indices_immediately:
                self.ids = new_ids
        else:
            kls = self.compute_kls(prior=flat_prior, posterior=flat_posterior)
            new_ids = None

        # Placeholders returned unchanged when nothing is sampled.
        block_kls = [0]
        block_sizes = [0]
        sample_list = []
        if self.no_compress:
            sample_list.append(posterior_update)
        else:
            for _ in range(self.num_indices):
                sampled_params, block_kls, block_sizes = self.rec(prior, posterior_update, self.ids, kls, project_blocks)
                sample_list.append(sampled_params)

        kls = tf.reduce_sum(kls)
        return sample_list, kls, prior, posterior_update, self.ids, block_kls, block_sizes, new_ids

    # @tf.function
    def cut_to_prob(self, array, thres=1e-5):
        """Clamp values into ``[thres, 1 - thres]`` using two ``tf.where`` passes."""
        lower, upper = thres, 1 - thres
        raised = tf.where(array <= lower, lower, array)
        return tf.where(raised >= upper, upper, raised)

    @tf.function
    def calculate_posterior(self, local_prm, server_prm, samples):
        """Per-entry log-likelihood ratio log q(x) - log p(x), in float64.

        ``samples`` holds 0/1 values x. q(x) uses ``local_prm`` and p(x) uses
        ``server_prm`` as the probability of x = 1.
        """
        def _log64(prob):
            return tf.math.log(tf.cast(prob, tf.float64))

        q_of_x = local_prm * samples + (1 - local_prm) * (1 - samples)
        p_of_x = server_prm * samples + (1 - server_prm) * (1 - samples)
        return _log64(q_of_x) - _log64(p_of_x)

    @tf.function
    def calculate_posterior_bool(self, local_prm, server_prm, samples):
        """Same log-likelihood ratio as ``calculate_posterior``, for boolean ``samples``."""
        def _log64(prob):
            return tf.math.log(tf.cast(prob, tf.float64))

        q_of_x = tf.where(samples, local_prm, 1 - local_prm)
        p_of_x = tf.where(samples, server_prm, 1 - server_prm)
        return _log64(q_of_x) - _log64(p_of_x)

    # @tf.function
    def rec(self, prior, posterior_update, ids, kls, project_blocks):
        """Sample the mask block by block with importance sampling against the prior.

        For every block ``[ids[j], ids[j+1])``, ``num_samples`` candidate 0/1
        vectors are drawn from ``prior``. One is chosen by ``block_sample``
        using the summed log ratios of posterior to prior.

        Args:
            prior: Flat Bernoulli probabilities of the prior (server) mask.
            posterior_update: Flat Bernoulli probabilities of the local mask.
            ids: Increasing block boundaries starting at 0 and ending at the
                vector length.
            kls: Per-entry KL values; only used to report per-block sums.
            project_blocks: If truthy, each block's posterior is first
                projected onto a KL ball of this radius around the prior.

        Returns:
            ``(sampled_params, block_kls, block_sizes)``. ``sampled_params`` is
            a flat float32 vector with 0.01 for sampled zeros and 0.99 for
            sampled ones.
        """
        # Re-seed TF's global RNG from NumPy's RNG.
        tf.random.set_seed(np.random.randint(0, 10000000))

        num_bits = int(self.kl_rate if self.adaptive else np.log2(self.num_samples))
        num_samples = 2 ** num_bits

        server_prm = tf.reshape(prior, [-1, 1])
        local_prm = tf.reshape(posterior_update, [-1, 1])

        # ``ids`` starts with 0, so as row limits it yields an empty first row; drop it.
        ragged_kls = tf.RaggedTensor.from_row_limits(values=kls, row_limits=ids)
        block_sizes = ragged_kls.row_lengths()[1:]
        block_kls = tf.reduce_sum(ragged_kls[1:], axis=1)

        @tf.function
        def process_segment(stacked_segment):
            # stacked_segment: [2, max_block, 1] = (local, server) for one block.
            l = stacked_segment[0]
            s = stacked_segment[1]
            samples = tf.cast(tf.random.uniform(shape=(tf.shape(l)[0], num_samples)) < s, dtype=tf.bool)
            joint_posterior_segment = self.calculate_posterior_bool(l, s, samples)
            index = self.block_sample(joint_posterior_segment, num_samples)
            return tf.reshape(tf.gather(samples, indices=index, axis=1), [-1])

        def process_segment_projection(bounds):
            start_i, end_i = bounds[0], bounds[1]
            block_prior = server_prm[start_i:end_i]
            block_local = local_prm[start_i:end_i]
            projected = self.project_onto_kl_ball(prior=tf.reshape(block_prior, [-1]),
                                                  posterior=tf.reshape(block_local, [-1]),
                                                  epsilon=project_blocks, client_id=0)
            # Column vector, like server_prm, so it broadcasts over the sample
            # axis of ``samples`` ([block_len, num_samples]). A flat vector
            # broadcast along the wrong axis (wrong weights, or a shape error).
            projected = tf.reshape(projected, [-1, 1])
            samples = tf.cast(
                tf.random.uniform(shape=(tf.shape(block_local)[0], num_samples)) < block_prior,
                dtype=tf.float32)
            joint_posterior_segment = self.calculate_posterior(projected, block_prior, samples)
            index = self.block_sample(joint_posterior_segment, num_samples)
            return tf.reshape(tf.gather(samples, indices=index, axis=1), [-1])

        if project_blocks:
            sampled_params = tf.map_fn(process_segment_projection, tf.stack([ids[:-1], ids[1:]], axis=1),
                                       fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.float32),
                                       parallel_iterations=200)
            # map_fn yields one ragged row per block; flatten so both paths
            # return the same flat vector.
            if isinstance(sampled_params, tf.RaggedTensor):
                sampled_params = sampled_params.flat_values
        else:
            max_output_size = tf.reduce_max(block_sizes)
            num_blocks = len(block_sizes)
            # Pad every block to the longest one with probability 1 for both
            # prior and posterior. Padded entries always sample 1 and add
            # log(1/1) = 0 to the weights, so they do not affect the choice.
            padded_local = tf.RaggedTensor.from_row_lengths(local_prm, block_sizes).to_tensor(
                default_value=1, shape=(num_blocks, max_output_size, local_prm.shape[-1]))
            padded_server = tf.RaggedTensor.from_row_lengths(server_prm, block_sizes).to_tensor(
                default_value=1, shape=(num_blocks, max_output_size, server_prm.shape[-1]))
            stacked_matrix = tf.stack([padded_local, padded_server], axis=1)
            sampled_params = tf.vectorized_map(process_segment, stacked_matrix)
            # Drop the padding: keep the first block_sizes[j] entries of row j, in row order.
            keep = tf.sequence_mask(block_sizes, maxlen=max_output_size)
            sampled_params = tf.boolean_mask(sampled_params, keep)

        sampled_params = tf.where(tf.cast(sampled_params, tf.float32) == 0, 0.01, 0.99)
        return sampled_params, block_kls, block_sizes

    @tf.function
    def block_sample(self, joint_posterior, num_samples):
        """Pick one candidate index with probability proportional to its importance weight.

        Args:
            joint_posterior: Per-entry log ratios ``[block_len, n_candidates]``.
            num_samples: Accepted for API compatibility; the candidate count
                is read from ``joint_posterior``.

        Returns:
            Scalar int64 candidate index.
        """
        # One log-weight per candidate. categorical() normalises its logits
        # itself, so passing them directly gives the same distribution as
        # exp -> normalise -> log. It also avoids overflow/underflow of exp()
        # for large |log weights| (float64 exp overflows above ~709).
        log_weights = tf.reduce_sum(joint_posterior, axis=0)
        index = tf.random.categorical(log_weights[None, :], 1)[0, 0]
        num_candidates = tf.cast(tf.shape(log_weights)[0], tf.int64)
        # Guard kept from the original: if categorical returns the
        # out-of-range value num_candidates (degenerate logits), pick uniformly.
        index = tf.cond(
            tf.equal(index, num_candidates),
            lambda: tf.random.uniform(shape=[], minval=0, maxval=num_candidates, dtype=tf.int64),
            lambda: index,
        )
        return index

    @tf.function
    def compute_kls(self, prior, posterior):
        """Element-wise KL(Bern(posterior) || Bern(prior)).

        ``q*log(q/p) + (1-q)*log((1-q)/(1-p))`` with q = posterior and
        p = prior.
        """
        q = posterior
        p = prior
        # xlogy(0, y) == 0. A posterior exactly 0 or 1 (float32 sigmoid
        # saturates) therefore contributes 0 instead of 0 * log(0) = NaN.
        # For q in (0, 1) this equals the plain formula.
        return tf.math.xlogy(q, q / p) + tf.math.xlogy(1 - q, (1 - q) / (1 - p))

    @tf.function
    def compute_tvs(self, prior, posterior):
        """Return ``0.5 * sum(|posterior - prior|)`` as a scalar.

        NOTE(review): for vectors of Bernoulli probabilities this is half the
        sum of per-entry total-variation distances. Confirm that scaling is
        intended.
        """
        abs_gap = tf.abs(posterior - prior)
        return 0.5 * tf.reduce_sum(abs_gap)

    # @tf.function
    def compute_indices(self, prior, posterior, old_ids):
        """Compute per-entry KLs and block boundaries for block-wise sampling.

        Partitioning modes:
          * not ``adaptive``: fixed blocks of ``block_size``;
          * ``adaptive`` and not ``adaptive_avg``: greedy. A boundary is placed
            where the running KL first exceeds ``kl_rate``, capped at
            ``max_block_size`` entries after the previous boundary;
          * ``adaptive_avg``: uniform blocks, about ``sum(kls) / kl_rate`` of them.

        Args:
            prior, posterior: Flat Bernoulli probability vectors.
            old_ids: If not None, returned unchanged as the boundaries.

        Returns:
            ``(ids, kls)``. ``ids`` are boundaries starting at 0 and ending at
            ``len(prior)``; ``kls`` are the per-entry KLs.
        """
        kls = self.compute_kls(prior, posterior)

        if old_ids is not None:
            return old_ids, kls

        n = len(prior)
        ids = [0]
        if not self.adaptive:
            ids = tf.unstack(tf.range(0, n, self.block_size))
            if ids[-1] < n:
                ids.append(n)
        elif not self.adaptive_avg:
            cumsum = np.cumsum(kls)
            while True:
                exceeded = np.flatnonzero(cumsum > self.kl_rate)
                if exceeded.size == 0:
                    break
                i = min(ids[-1] + self.max_block_size, exceeded[0])
                # Restart the running KL after the new boundary.
                cumsum -= cumsum[i]
                ids.append(i)
                if i % 100000 < 50: print(i)
            if ids[-1] < n:
                ids.append(n)
        else:
            n_blocks = tf.cast(tf.round(tf.reduce_sum(kls) / self.kl_rate), dtype=tf.int32)
            # At least one block. A total KL below kl_rate/2 rounds to 0
            # blocks, and the division below then produces an invalid step.
            n_blocks = tf.maximum(n_blocks, 1)
            suggested_block_size = tf.cast(tf.round(n / n_blocks), dtype=tf.int32)
            # tf.range rejects a zero step (more blocks than entries).
            suggested_block_size = tf.maximum(suggested_block_size, 1)
            ids = tf.unstack(tf.range(0, n, suggested_block_size))
            if ids[-1] < n:
                ids.append(n)

        return ids, kls

    def aggregate_ids(self, ids_s, balance=False):
        """Merge block boundaries proposed by several clients into one list.

        Without ``balance``: the k-th merged boundary is the ceiling of the
        mean of the k-th boundaries of the clients that have one. A boundary
        is kept only if it is larger than the previously kept one.

        With ``balance``: uniform blocks up to ``ids_s[0][-1]``, with a step
        equal to the rounded mean width of each client's first block.
        """
        if balance:
            steps = int(round(np.mean([x[1] - x[0] for x in ids_s])))
            ids = tf.unstack(tf.range(0, ids_s[0][-1], steps))
            if ids[-1] < ids_s[0][-1]:
                ids.append(ids_s[0][-1])
            return ids

        merged = []
        for position in range(max(len(r) for r in ids_s)):
            column = [boundaries[position] for boundaries in ids_s if len(boundaries) > position]
            candidate = int(np.ceil(sum(column) / len(column)))
            if not merged or candidate > merged[-1]:
                merged.append(candidate)
        return merged


class FedSCRECCompressor(Compressor):
    """Compress a model update by block-wise sampling from a Bernoulli prior.

    For each block of parameters, ``rec`` draws ``2 ** num_bits`` candidate
    boolean vectors from the prior (server) probabilities and keeps one,
    chosen with probability proportional to its posterior/prior likelihood
    ratio (see ``block_sample``). Block boundaries come from
    ``compute_indices`` / ``init``.
    """

    def __init__(self, adaptive=False, kl_rate=1, num_samples=256, block_size=256, max_block_size=512, use_indiv_reference=False, num_indices=1, no_compress=False, adaptive_avg=False, use_indices_immediately=True):
        """Store the compressor configuration and reset its running state.

        Args:
            adaptive: derive block boundaries from per-parameter KLs (see
                ``compute_indices``); in ``rec`` the number of bits per block is
                then ``int(kl_rate)`` instead of ``int(log2(num_samples))``.
            kl_rate: KL budget per block used by the adaptive partitioning.
            num_samples: candidate count used by ``rec`` when not adaptive.
            block_size: block length used by the non-adaptive partitioning.
            max_block_size: upper bound on a block length in the greedy
                adaptive partitioning.
            use_indiv_reference: mix each client's previous posterior into its
                prior (see ``extract_params``).
            num_indices: number of ``rec`` draws per ``compress`` call.
            no_compress: skip sampling and return the posterior itself.
            adaptive_avg: in adaptive mode, use equal-size blocks sized from
                the total KL instead of the greedy partition.
            use_indices_immediately: adopt freshly computed boundaries in the
                same ``compress`` call that computed them.
        """
        # NOTE(review): Compressor.__init__ is not called, so ``ef`` and
        # ``clipping`` are never set on instances of this class.

        # Block-partitioning configuration and state.
        self.adaptive = adaptive
        self.adaptive_avg = adaptive_avg
        self.kl_rate = kl_rate
        self.block_size = block_size
        self.max_block_size = max_block_size
        self.use_indices_immediately = use_indices_immediately
        self.update_blocks = True
        self.old_ids = None
        self.ids = None

        # Sampling configuration.
        self.num_samples = num_samples
        self.num_indices = num_indices
        self.no_compress = no_compress

        # Per-client reference posteriors.
        self.use_indiv_reference = use_indiv_reference
        self.reference = {}

        # Layer bookkeeping.
        self.layer_names = []
        self.mask_shapes = []
        self.mask_shapes_known = False

        # Round counter and Beta statistics (see aggregate_and_update).
        self.global_epoch = 0
        self.alphas = None
        self.betas = None

    def log_histograms(self, prior, posterior_update):
        """Log wandb histograms of ``sigmoid(prior)`` and ``sigmoid(posterior_update)``.

        The NumPy conversion and ``wandb.log`` call are wrapped in
        ``tf.py_function`` (presumably so this can be called from traced
        graph code). The current ``global_epoch`` is logged alongside.
        Returns None.
        """
        def _log_to_wandb(prior_values, posterior_values):
            wandb.log({
                'epoch': self.global_epoch,
                "prior": wandb.Histogram(tf.sigmoid(prior_values).numpy()),
                'posteriors': wandb.Histogram(tf.sigmoid(posterior_values).numpy()),
            })
            # py_function requires an output matching Tout.
            return 0

        tf.py_function(func=_log_to_wandb, inp=[prior, posterior_update], Tout=tf.float32)

    # @tf.function
    def process(self, model_old, model_new, client_id = 0, project_blocks=None, tf_models_provided=True):
        """Compress ``model_new`` against ``model_old`` and decompress the samples.

        Returns the same 8-tuple as ``compress``, except that its first element
        (the sample list) has been passed through ``decompress``.
        """
        (samples, total_kl, prior, posterior, block_ids,
         block_kls, block_sizes, new_ids) = self.compress(
            model_old,
            model_new,
            client_id=client_id,
            project_blocks=project_blocks,
            tf_models_provided=tf_models_provided,
        )
        decoded = self.decompress(samples)
        return decoded, total_kl, prior, posterior, block_ids, block_kls, block_sizes, new_ids

    # @tf.function
    def update(self, aggregated_gradients, model):
        """Assign ``aggregated_gradients[var.name]`` to each trainable variable of ``model``.

        Returns ``aggregated_gradients`` unchanged.
        """
        for variable in model.trainable_variables:
            new_value = aggregated_gradients[variable.name]
            variable.assign(new_value)
        return aggregated_gradients

    # @tf.function
    def aggregate_and_update(self, sample_list, model, reset=True, *, verbose=True):
        """Fold client samples into Beta statistics and write the result into ``model``.

        Args:
            sample_list: list of equally shaped flat tensors (one per client);
                they are stacked and summed element-wise.
            model: model whose ``trainable_variables`` are overwritten, in
                order, with consecutive slices of the aggregated vector.
            reset: restart the statistics from ``alphas = betas = 1`` before
                adding this round (always done when none exist yet).
            verbose: print the first 10 entries of the intermediate vectors
                (the original debugging output); pass False to silence it.

        Returns:
            0
        """
        sample_tensor = tf.stack(sample_list)
        sample_sum = tf.reduce_sum(sample_tensor, axis=0)
        num_clients = sample_tensor.shape[0]

        # NOTE(review): ``rec`` emits samples mapped to +/-0.99; if those are
        # what is passed here, alphas/betas are not 0/1 success counts -
        # confirm the intended input.
        if reset or self.alphas is None:
            self.alphas = 1
            self.betas = 1
        self.alphas += sample_sum
        self.betas += num_clients - sample_sum

        # Mode of Beta(alphas, betas).
        beta_mode = (self.alphas - 1) / (self.alphas + self.betas - 2)
        if verbose:
            print("Sum: ", beta_mode[:10])

        aggregated = tf.sigmoid(beta_mode)
        if verbose:
            print("Sigmoid sum: ", aggregated[:10])
            # NOTE(review): inverse_sigmoid is not defined in this class; it
            # must be provided at module level elsewhere in the file.
            print("Inverse sigmoid: ", inverse_sigmoid(aggregated[:10]))

        # Split the flat vector back into per-variable tensors, in the order of
        # model.trainable_variables.
        aggregated_gradients = {}
        start_index = 0
        for var in model.trainable_variables:
            end_index = start_index + var.shape.num_elements()
            aggregated_gradients[var.name] = tf.reshape(aggregated[start_index:end_index], var.shape)
            start_index = end_index

        self.update(aggregated_gradients, model)
        return 0

    def init(self, model):
        """Set ``self.ids`` to fixed-size block boundaries covering ``model``.

        The trainable variables are flattened and concatenated; ``self.ids``
        becomes the row limits ``[0, block_size, 2 * block_size, ..., num_params]``.
        """
        model_flattened = tf.concat([tf.reshape(var, [-1]) for var in model.trainable_variables], axis=0)
        num_params = len(model_flattened)
        ids = tf.unstack(tf.range(0, num_params, self.block_size))
        # Compare against num_params (not num_params - 1): when num_params - 1
        # is a multiple of block_size the final single-element block must still
        # be closed, otherwise the last parameter falls outside every block.
        if not ids or ids[-1] < num_params:
            ids.append(num_params)
        self.ids = ids

    def compute_model_kls(self, prior_gradients, posterior_gradients, verbose=False):
        """Return the summed ``compute_kls`` value over paired per-layer tensors.

        Each prior/posterior pair is flattened, then all layers are concatenated
        into one vector per side before the per-parameter KLs are summed into a
        scalar. With ``verbose`` the first 10 posterior and prior entries are
        printed (posterior first).
        """
        pairs = list(zip(prior_gradients, posterior_gradients))
        prior = tf.concat([tf.reshape(prior_part, [-1]) for prior_part, _ in pairs], axis=0)
        posterior_update = tf.concat([tf.reshape(posterior_part, [-1]) for _, posterior_part in pairs], axis=0)

        if verbose:
            print(posterior_update[:10].numpy())
            print(prior[:10].numpy())

        # Both vectors are already flat here.
        per_param_kl = self.compute_kls(prior=prior, posterior=posterior_update)
        return tf.reduce_sum(per_param_kl)


    def extract_params(self, prior_gradients, posterior_gradients, client_id=0):
        """Flatten paired per-layer tensors into one prior and one posterior vector.

        When ``use_indiv_reference`` is set:
          * if a reference is stored for ``client_id``, the returned prior is
            the equal-weight average of that reference and the flattened prior;
          * the flattened posterior is then stored as the new reference for
            ``client_id`` (after the old one was read).

        Returns:
            ``(prior, posterior)`` flat tensors.
        """
        flat_prior_parts = []
        flat_posterior_parts = []
        for prior_part, posterior_part in zip(prior_gradients, posterior_gradients):
            flat_prior_parts.append(tf.reshape(prior_part, [-1]))
            flat_posterior_parts.append(tf.reshape(posterior_part, [-1]))

        flat_prior = tf.concat(flat_prior_parts, axis=0)
        flat_posterior = tf.concat(flat_posterior_parts, axis=0)

        prior = flat_prior
        if self.use_indiv_reference:
            if client_id in self.reference:
                prior = self.reference[client_id] * 0.5 + 0.5 * flat_prior
            self.reference[client_id] = tf.identity(flat_posterior)

        return prior, flat_posterior

    # @tf.function
    def compress(self, prior, posterior_update, client_id=0, n_samples=1, project_blocks=None, tf_models_provided=True):
        """Encode ``posterior_update`` relative to ``prior`` by block-wise sampling.

        Args:
            prior: per-layer prior tensors (when ``tf_models_provided``) or a
                flat prior vector.
            posterior_update: per-layer posterior tensors or a flat vector.
            client_id: key for per-client references (see ``extract_params``).
            n_samples: unused; kept for backward compatibility.
            project_blocks: forwarded to ``rec``.
            tf_models_provided: flatten the inputs with ``extract_params`` first.

        Returns:
            ``(sample_list, total_kl, prior, posterior_update, ids, block_kls,
            block_sizes, new_ids)``:
              * ``sample_list`` holds ``num_indices`` outputs of ``rec``, or the
                posterior itself when ``no_compress`` is set;
              * ``block_kls`` / ``block_sizes`` come from the last ``rec`` call
                (``[0]`` when nothing was sampled);
              * ``ids`` is ``self.ids`` after this call;
              * ``new_ids`` is the freshly computed partition, or None when the
                partition was not recomputed.

        Raises:
            ValueError: if sampling is required but ``self.ids`` is None
                (boundaries were never initialised).
        """
        if tf_models_provided:
            prior, posterior_update = self.extract_params(prior, posterior_update, client_id=client_id)

        flat_prior = tf.reshape(prior, [-1])
        flat_posterior = tf.reshape(posterior_update, [-1])

        if not self.no_compress and self.update_blocks:
            new_ids, kls = self.compute_indices(prior=flat_prior, posterior=flat_posterior, old_ids=None)
            if self.use_indices_immediately:
                self.ids = new_ids
        else:
            kls = self.compute_kls(prior=flat_prior, posterior=flat_posterior)
            new_ids = None

        block_kls = [0]
        block_sizes = [0]
        sample_list = []
        if self.no_compress:
            sample_list.append(posterior_update)
        else:
            if self.num_indices > 0 and self.ids is None:
                raise ValueError("block boundaries (self.ids) are not set; call init(model) first")
            for _ in range(self.num_indices):
                sampled_params, block_kls, block_sizes = self.rec(prior, posterior_update, self.ids, kls, project_blocks)
                sample_list.append(sampled_params)

        return sample_list, tf.reduce_sum(kls), prior, posterior_update, self.ids, block_kls, block_sizes, new_ids

    # @tf.function
    def cut_to_prob(self, array, thres=1e-5):
        """Clamp ``array`` towards ``[thres, 1 - thres]`` in two sequential passes.

        First every value ``<= thres`` becomes ``thres``; then, on that result,
        every value ``>= 1 - thres`` becomes ``1 - thres``. NaNs compare False
        in both passes and are left unchanged.
        """
        upper = 1 - thres
        floored = tf.where(array <= thres, thres, array)
        return tf.where(floored >= upper, upper, floored)

    @tf.function
    def calculate_posterior(self, local_prm, server_prm, samples):
        """Element-wise log-likelihood ratio of ``samples`` under local vs server.

        An entry's likelihood under a parameter ``p`` is
        ``p * samples + (1 - p) * (1 - samples)``, i.e. ``p`` for a sample of 1
        and ``1 - p`` for a sample of 0. Returns
        ``log(local likelihood) - log(server likelihood)`` in float64.
        """
        def log_likelihood(prob):
            likelihood = prob * samples + (1 - prob) * (1 - samples)
            return tf.math.log(tf.cast(likelihood, tf.float64))

        return log_likelihood(local_prm) - log_likelihood(server_prm)

    @tf.function
    def calculate_posterior_bool(self, local_prm, server_prm, samples):
        """Boolean-sample log-likelihood ratio of ``samples`` under local vs server.

        Where ``samples`` is True an entry's likelihood is ``p``, otherwise
        ``1 - p`` (with ``p`` broadcast against ``samples`` by ``tf.where``).
        Returns ``log(local likelihood) - log(server likelihood)`` in float64.
        """
        def log_likelihood(prob):
            selected = tf.where(samples, prob, 1 - prob)
            return tf.math.log(tf.cast(selected, tf.float64))

        return log_likelihood(local_prm) - log_likelihood(server_prm)

    # @tf.function
    def rec(self, prior, posterior_update, ids, kls, project_blocks):
        """Draw one sampled parameter vector, block by block.

        For every block ``[ids[k], ids[k + 1])`` this draws ``2 ** num_bits``
        candidate boolean vectors (each entry True with the prior/server
        probability), weights each candidate by its posterior/prior likelihood
        ratio, and keeps one (see ``block_sample``). The selected booleans are
        finally mapped to ``0.99`` (True) and ``-0.99`` (False).

        Args:
            prior: server probabilities (flattened here).
            posterior_update: client probabilities, same number of elements.
            ids: block row limits; assumed to start at 0 and end at the number
                of parameters (as produced by ``compute_indices`` / ``init``).
            kls: per-parameter KL values, only used for the block statistics.
            project_blocks: if truthy, each block's posterior is first passed
                through ``self.project_onto_kl_ball`` with this value as
                ``epsilon``.

        Returns:
            ``(sampled_params, block_kls, block_sizes)`` where ``block_kls`` and
            ``block_sizes`` are the per-block KL sums and lengths.
        """
        # Reseed TF's global RNG from NumPy's RNG so successive calls draw
        # different candidate sets.
        tf.random.set_seed(np.random.randint(0, 10000000))

        num_bits = int(self.kl_rate if self.adaptive else np.log2(self.num_samples))
        num_samples = 2 ** num_bits

        # Column vectors: one row per parameter, broadcast against candidates.
        server_prm = tf.reshape(prior, [-1, 1])
        local_prm = tf.reshape(posterior_update, [-1, 1])

        # With ids[0] == 0 the first ragged row is empty; [1:] leaves one row
        # per block.
        ragged_kls = tf.RaggedTensor.from_row_limits(values=kls, row_limits=ids)
        block_sizes = ragged_kls.row_lengths()[1:]
        block_kls = tf.reduce_sum(ragged_kls[1:], axis=1)

        @tf.function
        def process_segment(stacked_segment):
            # stacked_segment[0] / [1]: padded local / server probabilities of one block.
            l = stacked_segment[0]
            s = stacked_segment[1]
            # Each candidate entry is True with the server (prior) probability.
            samples = tf.cast(tf.random.uniform(shape=(tf.shape(l)[0], num_samples)) < s, dtype=tf.bool)
            joint_posterior_segment = self.calculate_posterior_bool(l, s, samples)
            index = self.block_sample(joint_posterior_segment, num_samples)
            return tf.reshape(tf.gather(samples, indices=index, axis=1), [-1])

        def process_segment_projection(indices):
            # NOTE(review): project_onto_kl_ball is not defined in this class's
            # visible code; this branch relies on it being provided elsewhere.
            start_i, i = indices[0], indices[1]
            projected = self.project_onto_kl_ball(prior=tf.reshape(server_prm[start_i: i], [-1]),
                                                  posterior=tf.reshape(local_prm[start_i: i], [-1]),
                                                  epsilon=project_blocks, client_id=0)
            projected = tf.reshape(projected, [-1])
            samples = tf.cast(
                tf.random.uniform(shape=(tf.shape(local_prm[start_i: i])[0], num_samples)) < server_prm[start_i: i],
                dtype=tf.float32)
            joint_posterior_segment = self.calculate_posterior(projected, server_prm[start_i: i], samples)
            index = self.block_sample(joint_posterior_segment, num_samples)
            return tf.reshape(tf.gather(samples, indices=index, axis=1), [-1])

        if project_blocks:
            # NOTE(review): map_fn here yields a RaggedTensor, whereas the other
            # branch yields a flat tensor; confirm the final tf.where handles it.
            sampled_params = tf.map_fn(process_segment_projection, tf.stack([ids[:-1], ids[1:]], axis=1), fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.float32), parallel_iterations=200)
        else:
            max_output_size = tf.reduce_max(block_sizes)
            num_blocks = len(block_sizes)
            # Pad every block to max_output_size with probability 1 for both
            # local and server: padded candidate entries are then always True
            # and add log(1) - log(1) = 0 to the block's log-ratio.
            padded_local = tf.RaggedTensor.from_row_lengths(local_prm, block_sizes).to_tensor(
                default_value=1, shape=(num_blocks, max_output_size, local_prm.shape[-1]))
            padded_server = tf.RaggedTensor.from_row_lengths(server_prm, block_sizes).to_tensor(
                default_value=1, shape=(num_blocks, max_output_size, server_prm.shape[-1]))
            stacked_matrix = tf.stack([padded_local, padded_server], axis=1)

            sampled_params = tf.vectorized_map(process_segment, stacked_matrix)
            # Keep the first block_sizes[k] entries of each block and flatten,
            # in block order, with a single vectorised op.
            valid = tf.sequence_mask(block_sizes, maxlen=max_output_size)
            sampled_params = tf.boolean_mask(sampled_params, valid)

        sampled_params = tf.where(tf.cast(sampled_params, tf.float32) == 0, -0.99, 0.99)
        return sampled_params, block_kls, block_sizes

    @tf.function
    def block_sample(self, joint_posterior, num_samples):
        """Pick one candidate with probability proportional to its likelihood ratio.

        Args:
            joint_posterior: per-element log-ratios with candidates along axis 1
                (as returned by ``calculate_posterior*``); summing over axis 0
                gives each candidate's log-weight.
            num_samples: candidate count; unused, kept for the existing call
                signature.

        Returns:
            int64 scalar index of the selected candidate.
        """
        log_weights = tf.reduce_sum(joint_posterior, axis=0)
        num_candidates = tf.shape(log_weights, out_type=tf.int64)[0]
        # Feed the log-weights to tf.random.categorical as logits (it normalises
        # internally) instead of exponentiating first. In float64, exp() of a
        # sum above ~709 is inf and below ~-745 is 0. Either case made the
        # normalised weights NaN, which silently degraded the draw to uniform.
        index = tf.random.categorical(log_weights[None, :], 1)[0, 0]
        # Degenerate logits (e.g. all -inf or NaN) can make categorical return
        # an out-of-range index; fall back to a uniformly random candidate.
        index = tf.cond(
            tf.equal(index, num_candidates),
            lambda: tf.random.uniform(shape=[], minval=0, maxval=num_candidates, dtype=tf.int64),
            lambda: index,
        )
        return index

    @tf.function
    def compute_kls(self, prior, posterior):
        """Element-wise KL(Bernoulli(posterior) || Bernoulli(prior)).

        ``tf.math.xlogy(x, y)`` returns 0 when ``x == 0``. The ``0 * log(0)``
        terms that appear when a posterior probability is exactly 0 or 1
        (e.g. a saturated float32 sigmoid) therefore evaluate to 0, the
        standard convention, instead of NaN. Inputs are not clipped: a prior of
        exactly 0 or 1 paired with a different posterior yields ``inf``.

        Returns:
            Per-element KL tensor (shape of the broadcast inputs).
        """
        q = posterior
        p = prior
        return tf.math.xlogy(q, q / p) + tf.math.xlogy(1 - q, (1 - q) / (1 - p))

    @tf.function
    def compute_tvs(self, prior, posterior):
        """Return the scalar ``0.5 * sum(|posterior - prior|)``.

        NOTE(review): for Bernoulli coordinates the per-coordinate
        total-variation distance is ``|q - p|`` (no 0.5 factor); confirm which
        quantity is intended here.
        """
        abs_diff = tf.abs(posterior - prior)
        return 0.5 * tf.reduce_sum(abs_diff)

    # @tf.function
    def compute_indices(self, prior, posterior, old_ids):
        """Compute per-parameter KLs and the block boundaries used by ``rec``.

        Args:
            prior: flat prior (server) probabilities.
            posterior: flat posterior (client) probabilities, same length.
            old_ids: previously computed boundaries; when not None they are
                returned unchanged and only the KLs are recomputed.

        Returns:
            ``(ids, kls)``. ``ids`` is a list of row limits that starts at 0 and
            ends at ``len(prior)``. ``kls`` is the per-parameter tensor from
            ``compute_kls``. The partition depends on the configuration:
              * not ``adaptive``: fixed blocks of ``block_size``;
              * ``adaptive`` and not ``adaptive_avg``: greedy KL-budget blocks
                (see ``_greedy_kl_ids``), each at most ``max_block_size`` long;
              * ``adaptive`` and ``adaptive_avg``: equal-size blocks whose
                count is ``round(sum(kls) / kl_rate)`` (at least 1).

        Prints "compute kls" on every call.
        """
        print("compute kls")
        kls = self.compute_kls(prior, posterior)
        if old_ids is not None:
            return old_ids, kls

        num_params = len(prior)
        if not self.adaptive:
            ids = tf.unstack(tf.range(0, num_params, self.block_size))
        elif not self.adaptive_avg:
            ids = self._greedy_kl_ids(kls, num_params)
        else:
            # Clamp both to >= 1: a total KL below kl_rate / 2 would otherwise
            # divide by zero, and more blocks than parameters would give a zero
            # step, which tf.range rejects.
            n_blocks = tf.maximum(tf.cast(tf.round(tf.reduce_sum(kls) / self.kl_rate), dtype=tf.int32), 1)
            suggested_block_size = tf.maximum(tf.cast(tf.round(num_params / n_blocks), dtype=tf.int32), 1)
            ids = tf.unstack(tf.range(0, num_params, suggested_block_size))

        # Close the final block at num_params (also avoids IndexError on empty input).
        if not ids or ids[-1] < num_params:
            ids.append(num_params)
        return ids, kls

    def _greedy_kl_ids(self, kls, num_params):
        """Greedy KL-budget row limits (without the final ``num_params``).

        Walks the cumulative KL and closes a block just before the first
        parameter whose running KL exceeds ``kl_rate``, or after
        ``max_block_size`` parameters, whichever comes first. Prints progress
        for boundaries with ``i % 100000 < 50``.
        """
        ids = [0]
        cumsum = np.cumsum(kls)
        while True:
            over_budget = np.flatnonzero(cumsum > self.kl_rate)
            if over_budget.size == 0:
                break
            i = min(ids[-1] + self.max_block_size, over_budget[0])
            # NOTE(review): subtracting cumsum[i] (inclusive) means kls[i] is not
            # charged to the block starting at i; confirm this is intended.
            cumsum -= cumsum[i]
            ids.append(i)
            if i % 100000 < 50:
                print(i)
        # The loop only applies max_block_size while some later parameter is
        # still over budget; cap the remaining low-KL tail as well so no block
        # exceeds max_block_size.
        ids.extend(range(int(ids[-1]) + self.max_block_size, num_params, self.max_block_size))
        return ids

    def aggregate_ids(self, ids_s, balance=False):
        """Merge several boundary lists (e.g. one per client) into a single list.

        Args:
            ids_s: sequence of boundary lists.
            balance: when False, the i-th merged boundary is the ceiling of the
                mean of every list's i-th entry (lists shorter than i + 1 are
                ignored at that position), and any value that does not strictly
                exceed the previously kept boundary is dropped. When True, a
                uniform grid is returned instead: its step is the rounded mean
                of ``x[1] - x[0]`` over all lists, it runs from 0 to
                ``ids_s[0][-1]``, and that end value is appended if the grid
                stops short of it.

        Returns:
            The merged list of boundaries.
        """
        if balance:
            first_gaps = [bounds[1] - bounds[0] for bounds in ids_s]
            mean_step = int(round(np.mean(first_gaps)))
            end = ids_s[0][-1]
            merged = tf.unstack(tf.range(0, end, mean_step))
            # tf.range excludes its limit, so close the grid explicitly.
            if merged[-1] < end:
                merged.append(end)
            return merged

        longest = max(len(bounds) for bounds in ids_s)
        merged = []
        for position in range(longest):
            entries = [bounds[position] for bounds in ids_s if len(bounds) > position]
            candidate = int(np.ceil(sum(entries) / len(entries)))
            # Only keep strictly increasing boundaries.
            if not merged or candidate > merged[-1]:
                merged.append(candidate)
        return merged
