import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
from scipy import stats
from mpi4py import MPI

class DKmeans(object):
    """Two-round distributed coreset construction for k-means over MPI.

    Rank 0 acts purely as a coordinator and clusters no data of its own.
    Every other rank clusters its local data, samples points with
    probability proportional to their cost, and returns a weighted coreset
    made of its local centers followed by the sampled points.
    """

    def __init__(self, n_clusters):
        # Number of centers for both the local and the central clustering.
        self.n_clusters = n_clusters

    def distributed_coreset(self, data, n_points, comm, rank):
        """Build this rank's share of a distributed weighted coreset.

        Must be called by every rank in ``comm``: it performs a ``gather``
        and a ``bcast`` rooted at rank 0.

        Args:
            data: Local 2-D data array (one row per point); ignored on rank 0.
            n_points: Sampling budget shared among the worker ranks in
                proportion to their local cost. Each share is truncated to an
                int, so the total drawn may be slightly lower.
            comm: MPI communicator.
            rank: This process's rank in ``comm``.

        Returns:
            ``(coreset, weights)``. On worker ranks, ``coreset`` stacks the
            ``n_clusters`` local centers followed by the sampled points, and
            ``weights`` holds one weight per coreset row in the same order.
            These weights add up to the number of local points. On rank 0
            both values are ``None``.
        """
        # Round one: workers cluster locally and compute per-point costs.
        if rank != 0:
            local_kmeans = self.clustering_alg(data)
            centers = local_kmeans.cluster_centers_
            local_costs = self.get_cost(data, centers)
        else:
            local_kmeans = None
            centers = None
            local_costs = None

        # Every rank, including the coordinator, takes part in both collectives.
        cost_list = comm.gather(local_costs, root=0)
        if rank == 0:
            # Entry 0 is the coordinator's own ``None``; sum the workers' costs.
            total_cost = sum(np.sum(costs) for costs in cost_list[1:])
        else:
            total_cost = None
        total_cost = comm.bcast(total_cost, root=0)

        if rank == 0:
            return None, None

        # Round two: sample a share of the budget proportional to this rank's cost.
        local_cost_sum = np.sum(local_costs)
        if local_cost_sum > 0:
            local_n_points = int(n_points * local_cost_sum / total_cost)
        else:
            # Every local point lies exactly on a center. There is nothing to
            # sample, and dividing by the zero cost would give NaN probabilities.
            local_n_points = 0

        if local_n_points > 0:
            sample_prob = local_costs / local_cost_sum
            # Draw with replacement, so the same point may be drawn more than once.
            sample_idx = np.random.choice(
                local_costs.shape[0], size=local_n_points, p=sample_prob)
        else:
            sample_idx = np.zeros(0, dtype=np.intp)

        coreset = np.concatenate((centers, data[sample_idx]), axis=0)

        # Each draw q carries weight total_cost / (n_points * cost(q)).
        # Zero-cost points have zero probability, so the divisor is never 0.
        weight_p = total_cost / (n_points * local_costs[sample_idx])

        # Each center b gets |P_b| minus the weight of every draw that fell in
        # its cluster, duplicates included. This keeps the total coreset
        # weight equal to the number of local points.
        pred_labels = np.asarray(local_kmeans.predict(data))
        cluster_sizes = np.bincount(pred_labels, minlength=self.n_clusters)
        sampled_weight = np.bincount(
            pred_labels[sample_idx], weights=weight_p, minlength=self.n_clusters)
        # NOTE(review): this difference can be negative; confirm the
        # clustering backend accepts negative sample weights.
        weight_b = cluster_sizes - sampled_weight

        weights = np.concatenate((weight_b, weight_p))
        return coreset, weights

    def clustering_alg(self, data, weights=None):
        """Fit and return a ``KMeans`` model with ``n_clusters`` centers.

        ``weights`` is forwarded to ``fit`` as ``sample_weight``.
        """
        kmeans_model = KMeans(n_clusters=self.n_clusters)
        kmeans_model.fit(data, sample_weight=weights)
        return kmeans_model

    def get_cost(self, data, centers):
        """Return each row's Euclidean distance to its nearest center.

        NOTE(review): this is the plain (unsquared) distance. k-means
        sensitivity sampling usually uses the squared distance; confirm
        which cost is intended.
        """
        distance_matrix = euclidean_distances(data, centers)
        return np.min(distance_matrix, axis=1)

def distributed_linear_kmeans(n_clusters, n_points, feature, comm, rank):
    """Cluster data spread across MPI ranks using a distributed coreset.

    Must be called by every rank in ``comm``. Rank 0 acts as coordinator:
    it gathers the weighted coresets built by the other ranks, runs weighted
    k-means on their union, and broadcasts the resulting centers.

    Args:
        n_clusters: Number of clusters.
        n_points: Sampling budget shared among the worker ranks.
        feature: This rank's local data; ignored on rank 0.
        comm: MPI communicator.
        rank: This process's rank in ``comm``.

    Returns:
        ``(centers, n_upload_points)``. ``centers`` is the central model's
        ``cluster_centers_``, identical on every rank. ``n_upload_points`` is
        the total number of coreset rows sent to rank 0; it is only set on
        rank 0 and is ``None`` elsewhere.

    Raises:
        ValueError: On rank 0 when ``comm`` contains no worker ranks.
    """
    kmeans_model = DKmeans(n_clusters)
    coreset, weights = kmeans_model.distributed_coreset(feature, n_points, comm, rank)

    # Gather both arrays in one collective instead of two.
    gathered = comm.gather((coreset, weights), root=0)

    if rank == 0:
        # Entry 0 is the coordinator's own (None, None).
        worker_parts = gathered[1:]
        if not worker_parts:
            # Only rank 0 exists here, so raising cannot desynchronize peers.
            raise ValueError(
                "distributed_linear_kmeans needs at least one worker rank besides rank 0")
        # A single concatenate avoids the quadratic copy-per-rank pattern.
        central_coreset = np.concatenate([part[0] for part in worker_parts], axis=0)
        central_weights = np.concatenate([part[1] for part in worker_parts])
        central_kmeans = kmeans_model.clustering_alg(central_coreset, central_weights)
        centers = central_kmeans.cluster_centers_
        n_upload_points = central_weights.shape[0]
    else:
        centers = None
        n_upload_points = None

    # Only the centers are returned, so broadcast them rather than the whole model.
    centers = comm.bcast(centers, root=0)

    return centers, n_upload_points