from functools import partial

import jax
import jax.numpy as jnp  # JAX NumPy

@jax.jit
def sparsify(array, percentage):
    """Zero out the smallest-magnitude fraction of ``array``.

    Magnitudes are ranked over the whole (flattened) array. The magnitude at
    rank ``k = int(array.size * percentage)`` becomes the threshold, and every
    entry whose magnitude is strictly below it is replaced by ``0``. Entries
    tied with the threshold are kept, so at most ``k`` entries are zeroed.

    Args:
        array: Array of any shape.
        percentage: Fraction of entries to drop, expected in ``[0, 1]``.
            Out-of-range values are clamped: ``<= 0`` drops nothing and
            ``>= 1`` drops every (finite) entry.

    Returns:
        Array with the same shape and dtype as ``array`` in which the
        small-magnitude entries are 0.
    """
    abs_array = jnp.abs(array)
    ranked = jnp.sort(abs_array, axis=None)
    n = ranked.shape[0]  # static under jit
    if n == 0:
        # Nothing to rank; an always-true comparison leaves the (empty) array as is.
        threshold = jnp.inf
    else:
        # Clip the rank so the lookup below is always in bounds. Without this,
        # JAX silently clamps rank n to the maximum and wraps negative ranks.
        k = jnp.clip((n * percentage).astype(int), 0, n)
        # Rank n means "drop everything": no stored magnitude can serve as a
        # strict threshold, so use +inf instead.
        threshold = jnp.where(k >= n, jnp.inf, ranked[jnp.minimum(k, n - 1)])
    return jnp.where(abs_array < threshold, 0, array)

@partial(jax.jit, static_argnums=(2,))
def sparsify_index(array, percentage, size):
    """Return indices of the smallest-magnitude ``percentage`` of ``array``.

    Magnitudes are ranked over the flattened array. The magnitude at rank
    ``k = int(array.size * percentage)`` is the threshold, and every entry
    strictly below it is selected. ``percentage <= 0`` selects nothing and
    ``percentage >= 1`` selects every (finite) entry.

    Args:
        array: Input array.
        percentage: Fraction of entries to select, expected in ``[0, 1]``.
        size: Static length of the returned index vector, needed so the
            output has a fixed shape under ``jit``. If fewer entries are
            selected, the result is padded with ``0`` (``jnp.where``'s default
            fill), which is indistinguishable from a real index 0. If more
            entries are selected, the result is truncated.

    Returns:
        Integer array of shape ``(size,)`` with the selected entries' axis-0
        indices in ascending (row-major) order.

    NOTE(review): ``jnp.where`` returns one index array per dimension and only
    the first one is kept. For ``ndim > 1`` these are row indices, not flat
    positions; confirm callers only pass 1-D arrays.
    """
    abs_array = jnp.abs(array)
    ranked = jnp.sort(abs_array, axis=None)
    n = ranked.shape[0]  # static under jit
    if n == 0:
        threshold = jnp.inf
    else:
        # Clip the rank so the lookup is always in bounds. Without this, JAX
        # silently clamps rank n to the maximum and wraps negative ranks.
        k = jnp.clip((n * percentage).astype(int), 0, n)
        # Rank n means "select everything": use +inf as the strict threshold.
        threshold = jnp.where(k >= n, jnp.inf, ranked[jnp.minimum(k, n - 1)])
    return jnp.where(abs_array < threshold, size=size)[0]
# Function to quantize arrays
@jax.jit
def quantize(array, n_bits):
    """Map ``array`` linearly onto the levels ``0 .. 2**n_bits - 1``.

    The array's own minimum maps to level 0 and its maximum to the top level.
    Values are rounded to the nearest level and returned as floats. Only the
    levels are returned, so a caller that wants to invert this must record
    ``array.min()`` and ``array.max()`` separately.

    Args:
        array: Array to quantize.
        n_bits: Bit width; there are ``2**n_bits`` levels. This is computed
            with integer arithmetic, so very large widths may overflow.

    Returns:
        Array of the same shape holding each element's rounded level.
        A constant array (``max == min``) maps to all zeros rather than NaN.
    """
    max_val, min_val = array.max(), array.min()
    step = (max_val - min_val) / (2 ** n_bits - 1)
    # A constant array gives step == 0, and dividing by it would yield 0/0 = NaN.
    # Any non-zero step sends every element to level 0 here, and the
    # dequantize formula (level * step + min_val) maps that back to min_val.
    step = jnp.where(step == 0, 1, step)
    return ((array - min_val) / step).round()

# Function to dequantize arrays
@jax.jit
def dequantize(array, min_val, max_val, n_bits):
    """Map integer levels back onto the range ``[min_val, max_val]``.

    Uses the same level spacing as ``quantize``:
    ``(max_val - min_val) / (2**n_bits - 1)``. Level 0 maps to ``min_val``
    and the top level maps to ``max_val``.

    Args:
        array: Quantized levels.
        min_val: Value represented by level 0.
        max_val: Value represented by the top level.
        n_bits: Bit width used when quantizing.

    Returns:
        Array of reconstructed values with the same shape as ``array``.
    """
    n_intervals = 2 ** n_bits - 1
    level_width = (max_val - min_val) / n_intervals
    return array * level_width + min_val

# l2 distance
@jax.jit
def l2(x, y):
    """Return the *negated* Euclidean distance between ``x`` and ``y``.

    The squared differences are summed over every element, so inputs of any
    (broadcast-compatible) shape are treated as flat vectors. The sign is
    flipped, so a larger result means ``x`` and ``y`` are closer.
    """
    delta = x - y
    distance = jnp.sqrt(jnp.sum(delta ** 2))
    return -distance

@jax.jit
def get_diff(x, y):
    """Return the element-wise difference ``x - y`` (with standard broadcasting)."""
    return jnp.subtract(x, y)


def name_me(args):
    """Build a short run name from the experiment flags on ``args``.

    Each truthy feature flag contributes one letter, in this fixed order:
    ``evofed`` -> ``E``, ``linear`` -> ``L``, ``sparsify`` -> ``S``,
    ``quantize`` -> ``Q``, ``noise`` -> ``N``. If no flag is set, the base
    name is ``'BP'``. A ``-P<parts>`` suffix is added when
    ``args.parts > 1``, then a ``-C<n_clients>`` suffix when
    ``args.n_clients > 1``.

    Args:
        args: Any object exposing the attributes above (for example an
            ``argparse.Namespace``).

    Returns:
        The assembled name string.
    """
    flag_letters = (
        ('evofed', 'E'),
        ('linear', 'L'),
        ('sparsify', 'S'),
        ('quantize', 'Q'),
        ('noise', 'N'),
    )
    prefix = ''.join(
        letter for attr, letter in flag_letters if getattr(args, attr)
    ) or 'BP'

    suffixes = []
    if args.parts > 1:
        suffixes.append(f'-P{args.parts}')
    if args.n_clients > 1:
        suffixes.append(f'-C{args.n_clients}')
    return prefix + ''.join(suffixes)


def sum_and_average(arr, percentage):
    """Count the largest-magnitude entries needed to cover a share of the L1 mass.

    Magnitudes ``|arr|`` are sorted in descending order and accumulated until
    the running total first reaches ``percentage * sum(|arr|)``.

    Args:
        arr: JAX/NumPy array of any shape; it is flattened before ranking.
        percentage: Fraction of the total absolute sum to cover (e.g. ``0.9``).

    Returns:
        ``(count, average_value)``:
        - ``count`` is a Python ``int``: the number of largest-magnitude
          entries accumulated when the target is first reached. If the target
          is never reached (``percentage > 1``, NaNs, or float rounding),
          every entry is counted.
        - ``average_value`` is the mean magnitude of those entries.
        An empty input returns ``(0, 0.0)`` instead of dividing by zero.
    """
    magnitudes = jnp.abs(jnp.ravel(arr))
    n = magnitudes.shape[0]
    if n == 0:
        return 0, 0.0

    target_sum = percentage * jnp.sum(magnitudes)

    # Running totals of the magnitudes, largest first. This is computed
    # on-device in one pass rather than with a per-element Python loop that
    # syncs with the host on every comparison.
    running = jnp.cumsum(jnp.sort(magnitudes)[::-1])
    reached = running >= target_sum
    # argmax returns the index of the first True; if nothing reaches the
    # target, fall back to counting every entry.
    count = int(jnp.where(jnp.any(reached), jnp.argmax(reached) + 1, n))

    average_value = running[count - 1] / count
    return count, average_value