import numpy as np
import random
import math
import time
from sklearn.neighbors import KDTree
from sklearn.neighbors import BallTree
from sklearn.datasets import make_blobs
from itertools import combinations
from functools import partial
from multiprocessing.pool import Pool
from multiprocessing import RawArray
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances

# data, _ =make_blobs(n_samples=500000,n_features=2,cluster_std=0.1)

# read data
# Number of worker processes requested when a multiprocessing Pool is created
# by the search routines in this module.
MAX_PROCESSOR = 20
# Per-process registry filled by init_worker(); the worker routines read the
# shared data/weight buffers and their shapes from it.
var_dict = {}

# Holds a copy of the most recent neighbor-search result list produced by
# fast_local_search() (assigned there through a `global` statement).
r_tot_c = 0


# print(data.shape, W.sum())


def f_swap(args):
    """Route a tagged work item to the matching worker routine.

    ``args`` is a sequence whose first element is a tag and whose remaining
    elements are forwarded positionally: ``"SWAP"`` goes to ``SWAP_ORG`` and
    ``"SEARCH"`` goes to ``NEIGHBOR_SEARCH``.  Any other tag yields ``None``.
    """
    tag = args[0]
    payload = args[1:]
    if tag == "SWAP":
        return SWAP_ORG(*payload)
    elif tag == "SEARCH":
        return NEIGHBOR_SEARCH(*payload)
    return None


def SWAP_ORG(pair, init):
    """Evaluate the weighted cost of replacing one center by a candidate point.

    ``pair`` is ``[slot, candidate]``: ``slot`` indexes into ``init`` and
    ``candidate`` is a row index into the shared data buffer.  ``init`` itself
    is left untouched; a copy receives the replacement.

    Returns ``(cost, slot, candidate)`` where ``cost`` is the weighted sum of
    squared distances from every data row to its nearest center.
    """
    slot, candidate = pair[0], pair[1]
    # Rebuild array views over the buffers registered by init_worker().
    points = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    weights = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])

    trial = init.copy()
    trial[slot] = candidate
    centers = np.array(points[trial].copy())

    nearest, _ = BallTree(centers, leaf_size=40).query(points, k=1)
    squared = nearest[:, 0] ** 2
    return (squared * weights).sum(), slot, candidate


def _fls_best_result(results, bound):
    """Return the first result whose cost (element 0) is strictly below every
    earlier one and below ``bound``; ``None`` when no result qualifies."""
    best = None
    for res in results:
        if res[0] < bound:
            bound = res[0]
            best = res
    return best


def _fls_run(pool1, data, init, r, k, W):
    """Body of fast_local_search(): swap phase, Lloyd-style phase, neighbor
    phase and mutation, using ``pool1`` for the parallel cost evaluations."""
    global r_tot_c
    inf = 1000000000000000000000000000000000000000005
    n_points = data.shape[0]
    # With a threshold of 0 the stop test below is always true, so exactly one
    # outer pass is executed (this mirrors the original `fail >= 0` test).
    max_fail = 0
    cost_glob = inf
    fail = 0
    # The tree over the full data set never changes; build it once instead of
    # once per inner iteration.
    data_tree = KDTree(data.copy(), leaf_size=40)

    def weighted_sq_dist(center_ids):
        # Weighted squared distance of every point to its nearest center.
        tree = BallTree(np.array(data[center_ids]), leaf_size=40)
        nearest, _ = tree.query(data, k=1)
        return (nearest[:, 0] ** 2) * W

    wdist = weighted_sq_dist(init)
    cost_now = wdist.sum()
    prob = wdist / cost_now

    while True:
        # ---- Phase 1: D^2-sampled single swaps, evaluated in parallel ----
        for _ in range(0, r):
            if not cost_now > 0:
                # Zero cost makes `prob` NaN and np.random.choice would raise;
                # nothing can be improved anyway.
                break
            nextpoint = np.random.choice(n_points, 1, replace=False, p=prob)[0]
            tasks = [("SWAP", [j, nextpoint], init.copy()) for j in range(0, k)]
            best = _fls_best_result(list(pool1.imap(f_swap, tasks)), inf)
            if best is not None and best[0] < cost_now:
                # In-place update: the caller's `init` object sees this swap.
                init[best[1]] = best[2]
                wdist = weighted_sq_dist(init)
                cost_now = wdist.sum()
                prob = wdist / cost_now

        # ---- Phase 2: Lloyd-style refinement snapped to data points ----
        init_L = data[init]
        _, indL = KDTree(init_L, leaf_size=40).query(data, k=1)
        indL = indL[:, 0]
        while True:
            init_L_new = init_L.copy()
            for i1 in range(0, k):
                members = np.flatnonzero(indL == i1)
                w_members = W[members]
                w_total = w_members.sum()
                if members.size == 0 or w_total == 0:
                    # Empty / zero-weight cluster: keep the previous center
                    # rather than producing a NaN centroid.
                    continue
                init_L_new[i1] = (data[members] * w_members[:, None]).sum(axis=0) / w_total
            # Approximate each weighted centroid by its nearest data point.
            _, indL1 = data_tree.query(init_L_new, k=1)
            indL1 = indL1[:, 0]
            init_temp = data[indL1]
            distL, indL = KDTree(init_temp, leaf_size=40).query(data, k=1)
            indL = indL[:, 0]
            lloyd_cost = ((distL[:, 0] ** 2) * W).sum()
            if lloyd_cost < cost_now:
                cost_now = lloyd_cost
                init = indL1.copy()
                init_L = init_temp
            else:
                break

        # ---- Phase 3: try replacing each center by one of its neighbors ----
        init_new_np = data[init]
        for n_neighbors in (10, 20, 30, 50, 100):
            # Stay on the same neighborhood size while it keeps improving.
            while True:
                _, ind2 = data_tree.query(init_new_np, k=min(n_neighbors, n_points))
                tasks = [("SEARCH", [i2, cand], init.copy())
                         for i2 in range(0, k) for cand in ind2[i2]]
                r_tot = list(pool1.imap(f_swap, tasks))
                r_tot_c = r_tot.copy()
                best = _fls_best_result(r_tot, 10000000000000000000000000000000005)
                if best is None or not best[0] < cost_now:
                    break
                cost_now = best[0]
                init = init.copy()
                init[best[1]] = best[2]
                init_new_np = data[init]

        if cost_now < cost_glob:
            cost_glob = cost_now
        else:
            fail += 1

        # ---- Mutation: replace a random center by a random NON-center ----
        # The original computed `al.difference(init_set)` but discarded the
        # result, so an existing center could be drawn.
        outside = np.setdiff1d(np.arange(n_points), init)
        if outside.size > 0:
            rd1 = random.sample(range(0, len(outside)), 1)[0]
            rd2 = random.sample(range(0, len(init)), 1)[0]
            init[rd2] = outside[rd1]
        wdist = weighted_sq_dist(init)
        cost_now = wdist.sum()
        prob = wdist / cost_now

        if fail >= max_fail:
            return cost_glob


def fast_local_search(data, init, ybxl, r, t, k, W):
    """Run the pool-parallel local search and return the best weighted cost.

    The cost of a center set is the sum over all rows of ``data`` of
    ``W[i] * squared distance to the nearest center``; centers are always rows
    of ``data`` and are identified by their row index.

    Parameters
    ----------
    data : 2-D array (``data.shape[0]`` points, ``data.shape[1]`` features).
    init : indexable of ``k`` row indices into ``data``; it is modified in
        place by accepted swaps of the first phase.
    ybxl : unused; kept for interface compatibility.
    r : number of sampled swap rounds in the first phase.
    t : unused; kept for interface compatibility (it previously only sized a
        neighbor query whose extra columns were never read).
    k : number of centers.
    W : 1-D array of per-point weights.

    Returns
    -------
    The lowest weighted cost seen before the final mutation step.

    Worker processes receive ``data`` and ``W`` through shared ``RawArray``
    buffers registered by ``init_worker``.  The pool is released when the
    search finishes or raises.
    """
    print("-----------------------------------Start Original Local Search------------------------------------")
    data_shape = (data.shape[0], data.shape[1])
    data_g = RawArray('d', data_shape[0] * data_shape[1])
    w_shape = (W.shape[0])
    w_g = RawArray('d', w_shape)
    np.copyto(np.frombuffer(data_g).reshape(data_shape), data)
    np.copyto(np.frombuffer(w_g).reshape(w_shape), W)

    # All imap results are materialised inside _fls_run, so terminating the
    # pool on exit of the `with` block cannot drop pending work.
    with Pool(processes=MAX_PROCESSOR, initializer=init_worker,
              initargs=(data_g, data_shape, w_g, w_shape)) as pool1:
        cost_glob = _fls_run(pool1, data, init, r, k, W)

    # This report used to sit after the `return` and was never printed.
    print("--------------------Final Clustering Cost---------------------", cost_glob)
    return cost_glob


def generate_candidate_set(center_list):
    """Return every non-empty combination of ``center_list`` as a list of lists.

    Combinations are ordered by size (1 first, full length last) and, within a
    size, in the order produced by ``itertools.combinations``.
    """
    return [
        list(combo)
        for size in range(1, len(center_list) + 1)
        for combo in combinations(center_list, size)
    ]


def fast_local_search1(data, init, k, ybxl, r, t, c, alpha, epsilon, W):
    """Serial sampled local search; prints progress and returns the best cost.

    The cost of a center set is the sum over all rows of ``data`` of
    ``W[i] * squared distance to the nearest center``.  Centers are rows of
    ``data`` identified by row index.

    Each outer pass runs three phases followed by a mutation:

    1. ``r`` rounds in which up to three candidate points are drawn with
       probability proportional to their current cost share; every 1-, 2- and
       3-point replacement is scored on ``groups`` random subsamples and the
       best-scoring one is accepted if it lowers the full cost.
    2. Lloyd-style refinement with weighted centroids snapped to data points.
    3. Replacement of single centers by one of their 10/20/30/50 nearest
       data points.

    Parameters
    ----------
    data : 2-D array of points.
    init : indexable of ``k`` row indices into ``data``; may be modified in
        place by the mutation step.
    k : number of centers.
    ybxl, t, c, alpha : unused; kept for interface compatibility.
    r : number of sampling rounds in phase 1.
    epsilon : accuracy parameter that sizes the subsamples.
    W : 1-D array of per-point weights.

    Returns
    -------
    The lowest weighted cost seen before the final mutation (earlier versions
    returned ``None``; ``fast_local_search`` already returned this value).
    """
    print("-----------------------------------Start Fast Local Search------------------------------------")
    inf = 1000000000000000000000000000000000000000005
    groups = 5
    n_points = data.shape[0]
    sample_size = math.ceil(3 * k * math.log2(k / epsilon) / epsilon)
    # random.sample() raises when asked for more items than exist; clamp the
    # same way fast_local_search2 does.
    sample_size = min(n_points, sample_size)
    # With a threshold of 0 the stop test at the end of the pass is always
    # true, so exactly one outer pass runs (mirrors the original `fail >= 0`).
    max_fail = 0
    cost_glob = inf
    fail = 0
    swap_count = np.zeros(n_points)
    # The tree over the full data set never changes; build it once.
    data_tree = BallTree(data.copy(), leaf_size=40)

    while True:
        dist, _ = BallTree(data[init], leaf_size=40).query(data, k=1)
        dist = (dist[:, 0] ** 2) * W
        prob = dist / dist.sum()
        cost_now = dist.sum()

        print("Check Size", "Sample Size", sample_size, "Group Size", groups)
        tswap = time.time()
        # ---- Phase 1: sampled multi-swaps ----
        for i in range(0, r):
            # Independent per-point draw with probability 3 * prob[i].
            drawn = np.flatnonzero(prob * 3 - np.random.rand(n_points) >= 0)
            if len(drawn) < 1 or len(drawn) > 3:
                continue
            Min = inf
            init_f = None
            init_f_np = None
            id_out = []
            for subset in generate_candidate_set(list(drawn)):
                # combinations() yields position tuples in the same order as
                # the nested j < j1 < j2 loops it replaces.
                for positions in combinations(range(k), len(subset)):
                    init_new = init.copy()
                    for pos, point in zip(positions, subset):
                        init_new[pos] = point
                    init_new_np = data[init_new]
                    Tree1 = BallTree(init_new_np, leaf_size=40)
                    for _ in range(0, groups):
                        id_sample = np.array(random.sample(range(0, n_points), sample_size), dtype=int)
                        dist1, _ = Tree1.query(data[id_sample], k=1)
                        cost1 = ((dist1[:, 0] ** 2) * W[id_sample]).sum()
                        if cost1 < Min:
                            Min = cost1
                            id_out = [init[pos] for pos in positions]
                            init_f = init_new
                            init_f_np = init_new_np
            if init_f is None:
                # No replacement was scored (e.g. k == 0); nothing to verify.
                continue
            # The sampled score is only an estimate: confirm on the full data.
            dist1, _ = BallTree(init_f_np, leaf_size=50).query(data, k=1)
            dist1 = (dist1[:, 0] ** 2) * W
            cost_next = dist1.sum()
            if cost_next < cost_now:
                for removed in id_out:
                    swap_count[removed] -= 1
                cost_now = cost_next
                prob = dist1 / dist1.sum()
                init = init_f
                print("Round", i, "Has A Swap", cost_now)

        ct = 0
        # ---- Phase 2: Lloyd-style refinement snapped to data points ----
        tswap1 = time.time()
        print("Local Search Takes Time", tswap1 - tswap)
        print("Enter the Lloyd")
        tlloyd = time.time()
        init_L = data[init]
        _, indL = BallTree(init_L, leaf_size=40).query(data, k=1)
        indL = indL[:, 0]
        while True:
            init_L_new = init_L.copy()
            for i1 in range(0, k):
                members = np.flatnonzero(indL == i1)
                w_members = W[members]
                w_total = w_members.sum()
                if members.size == 0 or w_total == 0:
                    # Empty / zero-weight cluster: keep the previous center
                    # rather than producing a NaN centroid.
                    continue
                init_L_new[i1] = (data[members] * w_members[:, None]).sum(axis=0) / w_total
            # Approximate each weighted centroid by its nearest data point.
            _, indL1 = data_tree.query(init_L_new, k=1)
            indL1 = indL1[:, 0]
            init_temp = data[indL1]
            distL, indL = BallTree(init_temp, leaf_size=40).query(data, k=1)
            indL = indL[:, 0]
            distL = (distL[:, 0] ** 2) * W
            if distL.sum() < cost_now:
                cost_now = distL.sum()
                init = indL1.copy()
                init_L = init_temp
                print("Lloyd Has A Swap", cost_now)
            else:
                print("No Better", distL.sum())
                break
        tlloyd1 = time.time()
        print("Lloyd Takes Time", tlloyd1 - tlloyd)

        # ---- Phase 3: replace a center by one of its nearest data points ----
        tneighbor = time.time()
        print("Enter the Nearest Neighbot Search")
        init_new_np = data[init]
        # Weighted distances of the best candidate of the latest round; the
        # random mutation below samples from them.
        dist_f = None
        for n_neighbors in (10, 20, 30, 50):
            # Stay on the same neighborhood size while it keeps improving.
            while True:
                _, ind2 = data_tree.query(init_new_np, k=min(n_neighbors, n_points))
                Min = 10000000000000000000000000000000005
                best = None
                for i2 in range(0, k):
                    for cand in ind2[i2]:
                        trial_np = init_new_np.copy()
                        trial_np[i2] = data[cand]
                        dist3, _ = BallTree(trial_np, leaf_size=40).query(data, k=1)
                        dist3 = (dist3[:, 0] ** 2) * W
                        cost3 = dist3.sum()
                        if cost3 < Min:
                            Min = cost3
                            best = (i2, cand, trial_np)
                            dist_f = dist3
                if best is None or not Min < cost_now:
                    break
                slot, cand, trial_np = best
                swap_count[init[slot]] -= 1
                cost_now = Min
                init_new_np = trial_np
                init = init.copy()
                init[slot] = cand
                print("Round", ct, "Has A Swap", cost_now)
                ct += 1

        tneighbor1 = time.time()
        print("Neighbor Search Takes Time", tneighbor1 - tneighbor)

        # NOTE(review): surviving centers are credited twice (vectorised add
        # plus the loop); kept as in the original -- confirm this is intended.
        swap_count[init] += 1
        for center in init:
            swap_count[center] += 1
        if cost_now < cost_glob:
            cost_glob = cost_now
        else:
            fail += 1

        # ---- Mutation ----
        score_id = np.argsort(-swap_count)[0:k]
        score_id_set = set(score_id).difference(set(init.copy()))
        if len(score_id_set) < 1:
            print("Random Mutation")
            prob_mu = dist_f / dist_f.sum()
            mu_id = random.sample(range(0, k), 1)[0]
            swap_id = np.random.choice(n_points, 1, replace=False, p=prob_mu)[0]
            init[mu_id] = swap_id
        else:
            print("Score Mutation")
            mu_id = random.sample(range(0, k), 1)[0]
            swap_id = (list(score_id_set))[0]
            init[mu_id] = swap_id
        if fail >= max_fail:
            print("---------------------Final Clustering Cost---------------------", cost_glob)
            break

    return cost_glob


def init_worker(data, data_shape, w, w_shape):
    """Pool initializer: record the shared buffers and their shapes.

    Stores the four arguments in the module-level ``var_dict`` under the keys
    ``'data'``, ``'data_shape'``, ``'w'`` and ``'w_shape'`` so the worker
    routines can look them up later.
    """
    var_dict.update(
        data=data,
        data_shape=data_shape,
        w=w,
        w_shape=w_shape,
    )


def swap_t(args):
    """Route a tagged work item to the matching search routine.

    The first element of ``args`` selects the routine and the rest is
    forwarded positionally: ``"LOCAL_SEARCH"`` -> ``SWAP_SEARCH1``,
    ``"LOCAL_SEARCH1"`` -> ``SWAP_SEARCH2``, ``"NEIGHBOR_SEARCH"`` ->
    ``NEIGHBOR_SEARCH``.  Any other tag yields ``None``.
    """
    tag = args[0]
    payload = args[1:]
    if tag == "LOCAL_SEARCH":
        return SWAP_SEARCH1(*payload)
    elif tag == "LOCAL_SEARCH1":
        return SWAP_SEARCH2(*payload)
    elif tag == "NEIGHBOR_SEARCH":
        return NEIGHBOR_SEARCH(*payload)
    return None


def SWAP_SEARCH1(groups, L, init):
    """Score one multi-center replacement on pre-drawn sample groups.

    ``L`` is ``(slots, candidates)``: position ``slots[i]`` of a copy of
    ``init`` is set to ``candidates[i]``.  Each entry of ``groups`` supplies
    sample points at index 0 and their weights at index 1.

    Returns ``(cost, id_out, id_in)`` where ``cost`` is the lowest weighted
    sum of squared nearest-center distances over the groups, and ``id_out`` /
    ``id_in`` are ``[slots]`` / ``[candidates]`` (both stay empty, with the
    initial sentinel cost, when no group scores below the sentinel).
    """
    best_cost = 100000000000000000000000000000000005
    removed = []
    added = []
    # Rebuild an array view over the buffer registered by init_worker().
    points = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    slots, candidates = L[0], L[1]
    sample_points = [entry[0] for entry in groups]
    sample_weights = [entry[1] for entry in groups]

    trial = init.copy()
    for idx in range(len(slots)):
        trial[slots[idx]] = candidates[idx]
    tree = BallTree(points[trial], leaf_size=40)

    for pts, wts in zip(sample_points, sample_weights):
        nearest, _ = tree.query(pts, k=1)
        cost = ((nearest[:, 0] ** 2) * wts).sum()
        if cost < best_cost:
            best_cost = cost
            removed = [slots]
            added = [candidates]
    return best_cost, removed, added


def SWAP_SEARCH2(groups, sample_size, L, init):
    """Find the best way to insert a set of candidate points into ``init``.

    ``L[0]`` is the sequence of candidate row indices.  Every choice of
    ``len(L[0])`` distinct, increasing positions of ``init`` is tried: the
    candidates are written to those positions of a copy of ``init`` and the
    resulting center set is scored on ``groups`` random subsamples of
    ``sample_size`` points each, using the shared data and weight buffers
    registered by ``init_worker``.

    Returns ``(cost, id_out, id_in)``: the lowest sampled weighted cost, the
    list of replaced positions, and the list of inserted candidates.  When no
    replacement is scored below the initial sentinel, the sentinel and two
    empty lists are returned.

    Fixes relative to the previous version:
    * the number of centers is taken from ``len(init)`` instead of an
      undefined name ``k``;
    * for three candidates the third position now ranges over positions after
      the second one (it used to start at 0 and could coincide with the first
      or second position);
    * every three-position choice is scored (the sampling loop used to sit
      outside the innermost position loop, scoring only the last choice).
    Candidate sets of any size are handled by the same code path.
    """
    Min = 100000000000000000000000000000000005
    id_out = []
    id_in = []
    data1 = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    w = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])
    nextpoint = L[0]
    lg = len(nextpoint)
    if lg == 0:
        # Nothing to insert.
        return Min, id_out, id_in
    k = len(init)
    n_points = data1.shape[0]

    # combinations() yields increasing position tuples in the same order as
    # nested q < q1 < q2 loops would.
    for positions in combinations(range(k), lg):
        init_new = init.copy()
        for pos, point in zip(positions, nextpoint):
            init_new[pos] = point
        Tree1 = BallTree(data1[init_new], leaf_size=40)
        for _ in range(0, groups):
            id_sample = np.array(random.sample(range(0, n_points), sample_size), dtype=int)
            dist1, _ = Tree1.query(data1[id_sample], k=1)
            cost1 = ((dist1[:, 0] ** 2) * w[id_sample]).sum()
            if cost1 < Min:
                Min = cost1
                id_out = list(positions)
                id_in = list(nextpoint)
    return Min, id_out, id_in


def NEIGHBOR_SEARCH(L, init):
    """Evaluate replacing the center at position ``L[0]`` by data row ``L[1]``.

    Works on a copy of ``init`` and on the shared buffers registered by
    ``init_worker``.

    Returns ``(cost, position, candidate, sq_dist)`` where ``sq_dist`` holds
    the unweighted squared distance of every data row to its nearest center
    and ``cost`` is the weighted sum of those distances.
    """
    slot, candidate = L[0], L[1]
    points = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    weights = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])

    trial = init.copy()
    trial[slot] = candidate

    nearest, _ = BallTree(points[trial], leaf_size=40).query(points, k=1)
    sq_dist = nearest[:, 0] ** 2
    return (sq_dist * weights).sum(), slot, candidate, sq_dist


def kmeans_plus(data, k):
    """Pick ``k`` row indices of ``data`` by k-means++-style seeding.

    The first index is drawn uniformly; each later one is drawn with
    probability proportional to the squared distance of a row to its nearest
    already-chosen row.  Returns the chosen indices as a list.
    """
    chosen = []
    weights = 0
    n_points = data.shape[0]
    population = list(range(0, n_points))
    for step in range(0, k):
        if step == 0:
            pick = random.sample(range(0, n_points), 1)[0]
        else:
            pick = np.random.choice(population, p=weights, size=1, replace=False)[0]
        chosen.append(pick)

        # Refresh the sampling distribution for the next draw.
        centers = data[np.array(chosen.copy(), dtype=int)]
        nearest, _ = BallTree(centers.copy(), leaf_size=40).query(data.copy(), k=1)
        squared = nearest[:, 0] ** 2
        weights = squared.copy() / (squared.copy()).sum()

    return chosen


def Lloyd(data, k, W):
    """Baseline: run scikit-learn KMeans, snap centers to data, report cost.

    Fits ``KMeans`` (k-means++ init, one restart, at most 10 iterations) on
    ``data``, replaces each fitted center by its nearest data row, then prints
    and returns the ``W``-weighted sum of squared distances from every row to
    its nearest snapped center.
    """
    # NOTE(review): the fit itself does not receive W; only the reported cost
    # is weighted -- confirm an unweighted baseline is intended.
    model = KMeans(n_clusters=k, init="k-means++", n_init=1, max_iter=10)
    model.fit(data)

    _, snapped = BallTree(data, leaf_size=40).query(model.cluster_centers_, k=1)
    medoids = data[snapped[:, 0]]

    nearest, _ = BallTree(medoids, leaf_size=40).query(data, k=1)
    squared = nearest[:, 0] ** 2
    total = (squared * W).sum()
    print("------------------------Lloyd--------------------------", total)
    return total


def fast_local_search2(data, init, k, ybxl, r, t, c, alpha, epsilon, W, single, multi, triple):
    """Run the sampling-based local search for weighted k-means and return the best cost.

    Every outer iteration works on the current center set ``init`` (indices
    into ``data``) and runs four phases:

    1. up to ``r`` randomized swap rounds whose candidates are scored in
       parallel by the ``swap_t`` workers on random samples of the data;
    2. Lloyd-style refinement with the centroids snapped back to data points;
    3. single swaps of a center with one of its nearest data points
       (neighbourhood sizes 10, 20, 30, 50, 100), scored by the workers;
    4. a mutation of the center set that seeds the next outer iteration.

    The function returns at the end of the fifth outer iteration that did not
    improve the best cost seen so far.

    Parameters
    ----------
    data : 2-D array, one point per row.
    init : integer array with the ``k`` initial center indices.
    k : number of centers.
    ybxl, c, alpha : not used by this function; kept for call compatibility.
    r : number of swap rounds per outer iteration.
    t : number of neighbours requested when the current solution is scored
        (only the nearest one is used).
    epsilon : controls the sample size of phase 1 and the pool size of the
        random mutation.
    W : 1-D array of per-point weights.
    single, multi, triple : center positions, pairs of positions and triples
        of positions that may be swapped out for a candidate set of size
        one, two and three respectively.

    Returns
    -------
    The lowest weighted clustering cost found (it is also printed).
    """
    data_shape = (data.shape[0], data.shape[1])
    n_points = data_shape[0]

    # Shared-memory copies of the data and the weights; they are handed to
    # every worker through the pool initializer.
    data_g = RawArray('d', data_shape[0] * data_shape[1])
    data_np = np.frombuffer(data_g).reshape(data_shape)
    w_shape = (W.shape[0])
    w_g = RawArray('d', w_shape)
    w_np = np.frombuffer(w_g).reshape(w_shape)
    np.copyto(data_np, data)
    np.copyto(w_np, W)

    inf = math.inf
    groups = 5
    sample_size = math.ceil(3 * k * math.log2(k / epsilon) / epsilon)
    sample_size = min(n_points, sample_size)
    k_list = [10, 20, 30, 50, 100]
    cost_glob = inf
    fail = 0
    count_tot = 0
    swap_count = np.zeros(n_points)

    # The context manager shuts the worker processes down on every exit path;
    # previously the pool was never closed and each call leaked its workers.
    with Pool(processes=MAX_PROCESSOR, initializer=init_worker,
              initargs=(data_g, data_shape, w_g, w_shape)) as pool1:
        # The data never changes, so one tree serves every "nearest data
        # point" query below.
        data_tree = BallTree(data, leaf_size=40)

        while True:
            # Score the current solution.
            center_tree = BallTree(data[init], leaf_size=40)
            dist, _ = center_tree.query(data, k=t)
            dist = dist[:, 0] ** 2
            dist = dist * W
            cost_now = dist.sum()
            prob = dist / cost_now

            # ---- Phase 1: randomized swaps scored on samples ----
            for _ in range(0, r):
                # Each point is proposed independently with probability
                # 1.5 * (its share of the cost); only rounds proposing one or
                # two points are evaluated.
                prob_boost = prob * 1.5
                p1 = np.random.rand(n_points)
                id_range = np.argwhere(prob_boost - p1 >= 0)[:, 0]
                if len(id_range) < 1 or len(id_range) > 2:
                    continue
                next_points = generate_candidate_set(list(id_range))

                sample_groups = []
                for _ in range(0, groups):
                    id_sample = np.array(random.sample(range(0, n_points), sample_size), dtype=int)
                    sample_groups.append([data[id_sample], W[id_sample]])

                # Pair every candidate set with every set of center positions
                # of the same size.
                pair = []
                for cand in next_points:
                    if len(cand) == 1:
                        pair.extend([[out], cand] for out in single)
                    elif len(cand) == 2:
                        pair.extend([out, cand] for out in multi)
                    else:
                        pair.extend([out, cand] for out in triple)

                r_tot = list(pool1.imap(swap_t, [("LOCAL_SEARCH", sample_groups, p, init.copy())
                                                 for p in pair]))

                Min = inf
                id_out = []
                id_in = []
                for res in r_tot:
                    if res[0] < Min:
                        Min = res[0]
                        id_out = res[1]
                        id_in = res[2]
                init_f = init.copy()
                for n1 in range(0, len(id_in)):
                    init_f[id_out[n1]] = id_in[n1]

                # Re-score the best sampled swap on the full data set.
                swap_tree = BallTree(data[init_f], leaf_size=50)
                dist1, _ = swap_tree.query(data, k=1)
                dist1 = dist1[:, 0] ** 2
                dist1 = dist1 * W
                cost_next = dist1.sum()
                if cost_next < cost_now * (1 - 1 / (100 * k)):
                    # NOTE(review): id_out holds positions inside init, while
                    # swap_count is indexed by data index elsewhere -- confirm
                    # which one is meant here.
                    for m in range(0, len(id_out)):
                        swap_count[id_out[m]] -= 1
                    cost_now = cost_next
                    prob = dist1 / dist1.sum()
                    init = init_f.copy()

            # ---- Phase 2: Lloyd-type refinement ----
            init_L = data[init]
            _, indL = BallTree(init_L, leaf_size=40).query(data, k=1)
            indL = indL[:, 0]
            while True:
                # Float copy so centroids of integer data are not truncated.
                init_L_new = init_L.astype(float)
                for i1 in range(0, k):
                    id_i = np.argwhere(indL == i1)[:, 0]
                    w_i = W[id_i]
                    w_total = w_i.sum()
                    # An empty (or zero-weight) cluster keeps its old center
                    # instead of becoming a 0/0 NaN centroid.
                    if w_total > 0:
                        init_L_new[i1] = np.sum(data[id_i] * w_i[:, None], axis=0) / w_total
                # Snap the centroids to their nearest data points.
                _, indL1 = data_tree.query(init_L_new, k=1)
                indL1 = indL1[:, 0]
                init_temp = data[indL1]
                distL, indL = BallTree(init_temp, leaf_size=40).query(data, k=1)
                indL = indL[:, 0]
                distL = distL[:, 0] ** 2
                distL = distL * W
                if distL.sum() < cost_now:
                    cost_now = distL.sum()
                    init = indL1.copy()
                    init_L = init_temp.copy()
                else:
                    break

            # ---- Phase 3: swaps with nearest neighbours of the centers ----
            init_new_np = data[init]
            k_id = 0
            dist_f = 0
            while k_id < len(k_list):
                _, ind2 = data_tree.query(init_new_np, k=k_list[k_id])
                Min = inf
                id_out = -1
                id_in = -1
                pair = [[i2, ind2[i2][j]] for i2 in range(0, k) for j in range(0, ind2.shape[1])]
                r_tot = list(pool1.imap(swap_t, [("NEIGHBOR_SEARCH", p, init.copy()) for p in pair]))
                for res in r_tot:
                    if res[0] < Min:
                        id_in = res[2]
                        id_out = res[1]
                        Min = res[0]
                        dist_f = res[3]
                init_f_id = init.copy()
                init_f_id[id_out] = id_in
                init_f = data[init_f_id]
                if Min < cost_now:
                    # NOTE(review): id_out is a position inside init here as
                    # well (see the note in phase 1).
                    swap_count[id_out] -= 1
                    cost_now = Min
                    init_new_np = init_f.copy()
                    init = init_f_id.copy()
                else:
                    # No improving swap at this neighbourhood size: widen it.
                    k_id += 1

            swap_count[init] += 1

            # ---- Book-keeping of the best solution ----
            if cost_now < cost_glob:
                for i3 in range(0, len(init)):
                    swap_count[init[i3]] += 1
                cost_glob = cost_now
            else:
                fail += 1

            # ---- Phase 4: mutation ----
            score_id = np.argsort(-swap_count)[0:k]
            score_id_set = set(score_id).difference(set(init.copy()))

            if len(score_id_set) < 5 or count_tot <= 2:
                # Random mutation: score the current centers together with a
                # random pool of points and keep the k best scored ones.
                t_s1 = math.ceil(k * math.log2(k / 0.5) / epsilon)
                t_s1 = min(math.floor(2 * k), t_s1)
                init_ex = np.array(random.sample(range(0, n_points), t_s1), dtype=int)

                dist_e, ind_e = BallTree(data[init_ex], leaf_size=40).query(data, k=1)
                dist_e = dist_e[:, 0] ** 2
                dist_e = dist_e * W
                ind_e = ind_e[:, 0]
                # dist_e is already weighted; the former extra multiplication
                # by W applied the weights twice (unlike cost5 / cost_now).
                cost_ex = dist_e.sum()

                # NOTE(review): init_f and dist_f describe the last candidate
                # evaluated in phase 3, which is not necessarily the accepted
                # solution init -- confirm this is intended.
                _, indp = BallTree(init_f, leaf_size=40).query(data, k=1)
                indp = indp[:, 0]

                s_temp_list = np.array(list(set(list(init.copy()) + list(init_ex.copy()))), dtype=int)
                # One score per distinct candidate.  Sizing this by
                # len(init) + len(init_ex) let argsort return positions past
                # the end of s_temp_list whenever the two overlapped.
                score_f = np.zeros(len(s_temp_list))
                for i5 in range(0, len(init_ex)):
                    if i5 < k:
                        id5 = np.argwhere(indp == i5)[:, 0]
                        idi5 = np.argwhere(s_temp_list == init[i5])[:, 0]
                        cost5 = (dist_f[id5] * W[id5]).sum()
                        score_f[idi5] += cost5 / cost_now
                    id6 = np.argwhere(ind_e == i5)[:, 0]
                    idi6 = np.argwhere(s_temp_list == init_ex[i5])[:, 0]
                    cost6 = dist_e[id6].sum()
                    score_f[idi6] += 3 * cost6 / cost_ex
                id_score = np.argsort(-score_f)
                init = s_temp_list[id_score[0:k]]
            else:
                # Score mutation: the center whose cluster has the smallest
                # total variance is replaced by a high swap_count point.
                _, indh = BallTree(init_new_np, leaf_size=40).query(data, k=1)
                indh = indh[:, 0]
                std_min = inf
                id_min = -1
                for i3 in range(0, k):
                    id3 = np.argwhere(indh == i3)[:, 0]
                    std_t = np.var(data[id3], axis=0).sum()
                    if std_t < std_min:
                        std_min = std_t
                        id_min = i3
                swap_id = (list(score_id_set))[0]
                init[id_min] = swap_id

            count_tot += 1
            if fail >= 5:
                print("---------------------Final Clustering Cost---------------------", cost_glob)
                return cost_glob


def LLOYD1(data, init, W, k):
    """Run at most 10 weighted Lloyd iterations restricted to data points.

    Starting from the centers ``data[init]``, every iteration moves each
    center to the ``W``-weighted mean of its cluster, snaps the result to the
    nearest data point and keeps the new centers only if the weighted cost
    decreased; the first non-improving iteration stops the loop.

    Parameters
    ----------
    data : 2-D array, one point per row.
    init : indices of the ``k`` starting centers.
    W : 1-D array of per-point weights.
    k : number of centers.

    Returns
    -------
    The weighted sum of squared distances of the best center set found
    (``math.inf`` if no iteration produced a comparable cost).
    """
    init_L = data[init]
    _, indL = BallTree(init_L, leaf_size=40).query(data, k=1)
    indL = indL[:, 0]
    cost_now = math.inf
    # The data never changes, so the tree used for snapping is built once.
    data_tree = BallTree(data, leaf_size=40)
    for _ in range(10):
        # Float copy so centroids of integer data are not truncated.
        init_L_new = init_L.astype(float)
        for i1 in range(0, k):
            id_i = np.argwhere(indL == i1)[:, 0]
            w_i = W[id_i]
            w_total = w_i.sum()
            # An empty (or zero-weight) cluster keeps its old center instead
            # of becoming a 0/0 NaN centroid that breaks the tree query.
            if w_total > 0:
                init_L_new[i1] = np.sum(data[id_i] * w_i[:, None], axis=0) / w_total
        # Find the nearest data points to approximate the new centers.
        _, indL1 = data_tree.query(init_L_new, k=1)
        init_temp = data[indL1[:, 0]]
        # Recalculate the cost.
        distL, indL = BallTree(init_temp, leaf_size=40).query(data, k=1)
        indL = indL[:, 0]
        distL = distL[:, 0] ** 2
        distL = distL * W
        if distL.sum() < cost_now:
            cost_now = distL.sum()
            init_L = init_temp.copy()
            print("Lloyd Has A Swap", cost_now)
        else:
            print("No Better", distL.sum())
            break

    return cost_now


def Projection(X, centers):
    """Return, for every row of ``centers``, the nearest row of ``X``."""
    _, nearest = BallTree(X, leaf_size=40).query(centers, k=1)
    return X[nearest[:, 0]]


def k_means_cost(points, centers):
    """Assign every point to its closest center and total the squared distances.

    Returns ``(labels, cost)``: ``labels[i]`` is the index of the center
    closest to ``points[i]`` and ``cost`` is the sum over all points of the
    squared distance to that center.
    """
    sq_dist = euclidean_distances(points, centers) ** 2
    nearest_center = np.argmin(sq_dist, axis=1)
    total = np.min(sq_dist, axis=1).sum()
    return nearest_center, total


if __name__ == '__main__':
    # Benchmark driver: for every dataset / k combination, draw k random data
    # points as centers n_runs times and collect min / max / mean / std of the
    # k-means cost together with the mean wall-clock time per run.
    n_runs = 10
    total_list = []
    name_p = ['SIFT']
    for d_i in range(len(name_p)):
        for d_j in np.array([10]):
            first_name = str(name_p[d_i]) + "_k=" + str(d_j)
            print(first_name)
            cost_list = [first_name]
            k = d_j
            print("k:", k)
            dataset_cost10_list = []

            # Load the dataset; a few datasets need a dedicated reader.
            data_path = "data/" + str(name_p[d_i])
            if (name_p[d_i] == "USC1990"):
                data_path = data_path + ".txt"
                data = np.loadtxt(data_path, delimiter=',')
            elif (name_p[d_i] == "SUSY"):
                data_path = data_path + ".csv"
                data = np.loadtxt(data_path, usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                                  delimiter=',')
            elif (name_p[d_i] == "HIGGS"):
                data_path = data_path + ".csv"
                data = np.loadtxt(data_path, delimiter=',')
                data = data[:, 1:data.shape[1]]
            elif (name_p[d_i] == "SIFT"):
                # Each record is a 4-byte int32 dimension d followed by d
                # uint8 components; the 4 header bytes are dropped.
                x = np.memmap("learn.bvecs", dtype='uint8', mode='r')
                d = x[:4].view('int32')[0]
                data = x.reshape(-1, d + 4)[:, 4:]
                data = np.array(data)
            else:
                data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')

            tt = 0
            for d_k in range(n_runs):
                t0 = time.time()

                print("ready to start")

                # Centers of this run: k distinct data points drawn uniformly.
                cluster_centers_ = data[random.sample(range(0, data.shape[0]), k)]

                print("lloyd finished")

                # Snap the centers to their nearest data points and score
                # them.  This used to read km.cluster_centers_ from a
                # MiniBatchKMeans that was never fitted, which raises
                # AttributeError.
                distance = euclidean_distances(cluster_centers_, data)
                distance = distance ** 2
                labels = np.argmin(distance, axis=1)
                centers_new = data[labels]

                _, data_cost = k_means_cost(data, centers_new)

                t2 = time.time()
                dataset_cost10_list.append(data_cost)
                tt += t2 - t0

            print("NEW", tt)
            tt = tt / n_runs
            dataset_cost10_list = np.array(dataset_cost10_list)
            cost_list.append(np.min(dataset_cost10_list))
            cost_list.append(np.max(dataset_cost10_list))
            cost_list.append(np.mean(dataset_cost10_list))
            cost_list.append(np.std(dataset_cost10_list))
            cost_list.append(tt)
            total_list.append(cost_list)

    print(total_list)