from clustering import Coordinator, Worker, DSPGD, distributed_linear_kmeans
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from numpy import linalg as LA
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.cluster import KMeans
import argparse
from mpi4py import MPI

def parse_args():
    """Parse the command-line options of the multiprocess simulation.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace with the attributes n_clusters, n_components,
        n_nodes, gamma, tol, dataset, n_points, index, n_dim and n_iterations.
    """
    # (flag, type, default, help) -- registered in this exact order.
    option_table = (
        ('--n_clusters', int, 2, 'number of clusters'),
        ('--n_components', int, 2, 'number of components'),
        ('--n_nodes', int, 2, 'number of nodes'),
        ('--gamma', float, 0.5, 'parameter gamma of rbf kernel'),
        ('--tol', float, 1e-2, 'tolerance'),
        ('--dataset', str, 'mushroom', 'dataset'),
        ('--n_points', int, 500, 'the number of sampled points of coreset'),
        ('--index', int, 0, 'experiment index'),
        ('--n_dim', int, 10, 'dimensions of output'),
        ('--n_iterations', int, 10, 'itertion'),
    )

    cli = argparse.ArgumentParser(description='multiprocess simulation')
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, type=value_type, help=help_text)
    return cli.parse_args()

def load_data(dataset, rank, root='./datasets'):
    """Load one shard of features and labels for ``dataset``.

    The files read are::

        <root>/<dataset>/<dataset>_feature<rank>.npy
        <root>/<dataset>/<dataset>_label<rank>.npy

    Args:
        dataset: name of the dataset; one of 'mushroom', 'mnist', 'covtype'.
        rank: shard index appended to the file names (the caller passes the
            MPI rank of the worker).
        root: directory holding the per-dataset sub-directories. The default
            reproduces the previously hard-coded './datasets' location.

    Returns:
        (x, y): the arrays returned by ``np.load`` for the feature file and
        the label file respectively.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names. This is
            a subclass of ``Exception``, so existing handlers still catch it.
    """
    supported = ('mushroom', 'mnist', 'covtype')
    # Validate before touching the filesystem, as the original branches did.
    if dataset not in supported:
        raise ValueError('Dataset number error', dataset)

    # All supported datasets share the same on-disk naming scheme.
    prefix = '{}/{}/{}'.format(root, dataset, dataset)
    x = np.load(prefix + '_feature' + str(rank) + '.npy')
    y = np.load(prefix + '_label' + str(rank) + '.npy')
    return x, y


if __name__ == '__main__':
    # Every MPI process runs this body. Rank 0 acts as the coordinator
    # ("cloud server"); every other rank is a worker that loads its own data
    # shard via load_data(dataset, rank).
    args = parse_args()
    n_clusters = args.n_clusters
    n_components = args.n_components
    n_nodes = args.n_nodes
    gamma = args.gamma
    # Tolerance passed to every DSPGD_update_CEM call. It is taken from the
    # CLI and deliberately not overridden later, so --tol takes effect.
    tol = args.tol
    dataset = args.dataset
    n_points = args.n_points
    dim = args.n_dim
    n_iterations = args.n_iterations

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Number of leading eigen-directions kept as the clustering feature.
    # NOTE(review): 'huawei' is accepted here, but load_data() does not
    # support it, so workers would fail for that dataset.
    select_dims = {'mushroom': 4, 'covtype': 10, 'mnist': 12, 'huawei': 10}
    if dataset not in select_dims:
        raise ValueError('Unknown dataset', dataset)
    select_dim = select_dims[dataset]

    if rank == 0:
        # init_vec comes from a fixed RandomState(0). The two seeds come from
        # NumPy's global generator, which this file never seeds, so they
        # vary between runs; they are printed so a run can be reproduced.
        generator = np.random.RandomState(0)
        init_vec = generator.normal(size=2 * n_components)
        weight_seed = np.random.randint(65536, size=1)
        bias_seed = np.random.randint(65536, size=1)
        print('weight_seed:', weight_seed)
        print('bias_seed:', bias_seed)
    else:
        init_vec = None
        weight_seed = None
        bias_seed = None

    # Distribute rank 0's values to every rank. init_vec is broadcast but not
    # used further in this script.
    init_vec = comm.bcast(init_vec, root=0)
    weight_seed = comm.bcast(weight_seed, root=0)
    bias_seed = comm.bcast(bias_seed, root=0)

    if rank == 0:
        process = Coordinator(n_nodes=n_nodes, max_rank=n_components)
        feature = None
        label = None
    else:
        process = Worker(rank=rank, n_components=n_components, gamma=gamma,
                         weight_seed=weight_seed, bias_seed=bias_seed)
        feature, label = load_data(dataset, rank)

    optimizor = DSPGD(n_clusters, n_components, gamma, rank)

    # First update. It is called without the step-size / previous-state
    # arguments that the later iterations pass in.
    ite = 1
    eta = 1.0
    eigvals, local_eigvecs, comm_cost = optimizor.DSPGD_update_CEM(
        process, feature, ite, dim, comm, tol)

    lamb = 0.9 * eigvals[-1]
    ite = ite + 1

    # Collect ground-truth labels on rank 0. Rank 0 itself contributes None,
    # hence the [1:] slice.
    label_set = comm.gather(label, root=0)
    true_label = process.concatenation(label_set[1:]) if rank == 0 else None

    for _ in range(1, n_iterations):
        eta = 1. / ite
        # Only rank 0 computes the vector size; workers pass None.
        vec_size = eigvals.shape[0] + n_components if rank == 0 else None
        eigvals, local_eigvecs, comm_cost = optimizor.DSPGD_update_CEM(
            process, feature, ite, dim, comm, tol,
            eta * lamb, eigvals, local_eigvecs, vec_size)
        ite = ite + 1

    # Workers build their local clustering features from the leading
    # select_dim eigenvectors, scaled by sqrt(eigval + (1 - eta) * lamb).
    # With no loop iterations, eta stays 1.0 and the shift is 0.
    if rank != 0:
        cluster_feature = np.matmul(
            local_eigvecs[:, :select_dim],
            np.diag(np.sqrt(eigvals[:select_dim] + (1 - eta) * lamb)))
    else:
        cluster_feature = None

    # The second return value (number of uploaded points) is not used here.
    centroids, _ = distributed_linear_kmeans(n_clusters, n_points, cluster_feature, comm, rank)

    if rank != 0:
        cluster_idx = np.argmin(euclidean_distances(cluster_feature, centroids), axis=1)
    else:
        cluster_idx = None
    pred_set = comm.gather(cluster_idx, root=0)

    if rank == 0:
        # Join the workers' assignments in rank order with one concatenate
        # (rank 0's None entry is skipped).
        pred = np.concatenate(pred_set[1:])
        # average_method is keyword-only in current scikit-learn releases.
        nmi_score = normalized_mutual_info_score(true_label, pred, average_method='arithmetic')
        print('NMI score:', nmi_score)
















