#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from scipy.stats import bernoulli


def sampling_clients(sampling: str, n_clients: int, n_sampled: int, weights: np.array):
    """Sample clients with the requested scheme.

    Args:
        sampling: Name of the sampling scheme. One of ``"Full"``, ``"MD"``,
            ``"Improved"``, ``"Uniform"``, ``"Binomial"`` or ``"Poisson"``.
        n_clients: Total number of clients.
        n_sampled: Number of clients to sample (expected number for the
            ``"Binomial"`` and ``"Poisson"`` schemes).
        weights: Per-client weights, indexed by client. They are passed as
            the probability vector ``p`` of ``np.random.choice`` for the
            ``"MD"`` and ``"Improved"`` schemes.

    Returns:
        A tuple ``(selected, agg_weights)`` of two arrays of length
        ``n_clients``: ``selected[i]`` is 1 when client ``i`` was sampled and
        0 otherwise, and ``agg_weights[i]`` is the aggregation weight
        associated with client ``i``.

    Raises:
        ValueError: If ``sampling`` is not a known scheme, or if the
            ``"Improved"`` scheme is asked for more distinct clients than
            there are clients with a non-zero weight.
    """

    if sampling == "Full":
        # Every client participates with its own weight. Return a copy so
        # that, as in every other branch, the caller gets a fresh array.
        selected = np.ones(n_clients)
        agg_weights = np.copy(weights)

    elif sampling == "MD":
        # n_sampled draws with replacement; a client drawn several times
        # accumulates 1 / n_sampled per draw.
        selected = np.zeros(n_clients)
        agg_weights = np.zeros(n_clients)

        selected_idx = np.random.choice(
            n_clients, size=n_sampled, replace=True, p=weights
        )

        for idx in selected_idx:
            selected[idx] = 1
            agg_weights[idx] += 1 / n_sampled

    elif sampling == "Improved":
        # Keep drawing until n_sampled *distinct* clients have been seen.
        # Clients with zero weight can never be drawn, so without this guard
        # the loop below would never terminate.
        if n_sampled > np.count_nonzero(weights):
            raise ValueError(
                "Cannot sample {} distinct clients: only {} clients have a "
                "non-zero weight".format(n_sampled, np.count_nonzero(weights))
            )

        selected = np.zeros(n_clients)
        agg_weights = np.zeros(n_clients)

        while np.sum(selected) < n_sampled:
            selected_idx = np.random.choice(n_clients, size=1, p=weights)
            selected[selected_idx] = 1
            agg_weights[selected_idx] += 1

        # Aggregation weights are the normalised draw counts.
        agg_weights /= np.sum(agg_weights)

    elif sampling == "Uniform":
        # n_sampled distinct clients, each equally likely to be picked;
        # weights are rescaled by n_clients / n_sampled.
        selected = np.zeros(n_clients)
        agg_weights = np.zeros(n_clients)

        selected_idx = np.random.choice(
            n_clients, size=n_sampled, replace=False)

        for idx in selected_idx:
            selected[idx] = 1
            agg_weights[idx] = n_clients / n_sampled * weights[idx]

    elif sampling == "Binomial":
        # Each client is kept independently with the same probability.
        p_sampling = n_sampled / n_clients
        selected = bernoulli.rvs(p_sampling, size=n_clients)

        agg_weights = np.multiply(selected, weights) / p_sampling

    elif sampling == "Poisson":
        # Each client i is kept independently with probability
        # n_sampled * weights[i], which therefore has to lie in [0, 1].
        selected = np.array([bernoulli.rvs(n_sampled * pi) for pi in weights])
        agg_weights = selected / n_sampled

    else:
        raise ValueError("Unknown sampling scheme: {!r}".format(sampling))

    return selected, agg_weights



import numpy as np
from numpy.random import choice


def get_clusters_with_alg1(n_sampled: int, weights: np.array):
    """Build the client distribution of each cluster ("Algorithm 1").

    Every client weight is multiplied by ``n_sampled`` and the resulting
    mass is poured, client by client in decreasing weight order, into
    ``n_sampled`` clusters that each hold a total mass of exactly 1. A
    client whose mass does not fit in the current cluster is split over
    consecutive clusters. Each cluster row is finally normalised to sum to 1.

    Args:
        n_sampled: Number of clusters to build.
        weights: Per-client weights, indexed by client. They are expected to
            sum to 1 so that the total mass fills the ``n_sampled`` clusters.

    Returns:
        A float array of shape ``(n_sampled, n_clients)`` whose row ``k`` is
        the probability distribution over clients for cluster ``k``.

    Raises:
        ValueError: If the weights carry more mass than the clusters can
            hold, or so little that at least one cluster stays empty.
    """

    # Masses are handled as integer multiples of 1 / epsilon so that "the
    # cluster is full" can be tested with an exact equality.
    epsilon = int(10 ** 10)

    augmented_weights = np.array([w * n_sampled * epsilon for w in weights])
    ordered_client_idx = np.flip(np.argsort(augmented_weights))

    n_clients = len(weights)
    # Explicit 64-bit integers: the platform default integer can be 32-bit,
    # which cannot represent epsilon.
    distri_clusters = np.zeros((n_sampled, n_clients), dtype=np.int64)

    k = 0
    mass_in_k = 0  # integer mass already stored in cluster k
    for client_idx in ordered_client_idx:

        while augmented_weights[client_idx] > 0:

            if k >= n_sampled:
                raise ValueError(
                    "weights carry more mass than the {} clusters can hold; "
                    "they must sum to 1".format(n_sampled)
                )

            # Give the client whatever is smaller: the room left in the
            # cluster or the mass the client still has to place.
            u_i = min(epsilon - mass_in_k, augmented_weights[client_idx])

            distri_clusters[k, client_idx] = u_i
            augmented_weights[client_idx] -= u_i

            # Read the value back: the integer array stores u_i truncated.
            mass_in_k += int(distri_clusters[k, client_idx])
            if mass_in_k == epsilon:
                k += 1
                mass_in_k = 0

    cluster_mass = distri_clusters.sum(axis=1, keepdims=True)
    if np.any(cluster_mass == 0):
        raise ValueError(
            "at least one cluster received no probability mass; "
            "weights must sum to 1"
        )

    return distri_clusters / cluster_mass


def sampling_clients_cluster(distri_clusters:np.array):
    """Sample one client from each cluster distribution.

    Args:
        distri_clusters: 2-D array of shape ``(n_sampled, n_clients)``; row
            ``k`` is the probability distribution over clients used for the
            ``k``-th draw.

    Returns:
        A tuple ``(selected, agg_weights)`` of two arrays of length
        ``n_clients``: ``selected[i]`` is 1 when client ``i`` was drawn at
        least once and 0 otherwise, and ``agg_weights[i]`` is the number of
        times client ``i`` was drawn divided by ``n_sampled``.
    """

    # Reading the shape (rather than the length of the first row) also works
    # when there are no clusters at all.
    distri_clusters = np.asarray(distri_clusters)
    n_sampled, n_clients = distri_clusters.shape

    selected = np.zeros(n_clients)
    agg_weights = np.zeros(n_clients)

    for cluster_distri in distri_clusters:
        # choice(..., 1, ...) returns a 1-element array; take its element
        # explicitly instead of converting the array with int().
        idx_selected = choice(n_clients, 1, p=cluster_distri)[0]
        selected[idx_selected] = 1
        agg_weights[idx_selected] += 1/n_sampled

    return selected, agg_weights
