import numpy as np
import random
import math
import time
from sklearn.neighbors import KDTree
from sklearn.neighbors import BallTree
from sklearn.datasets import make_blobs
from itertools import combinations
from functools import partial
from multiprocessing.pool import Pool
from multiprocessing import RawArray
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import pandas as pd
from sklearn.metrics.pairwise import pairwise_distances

# Example synthetic input (disabled):
#   data, _ = make_blobs(n_samples=500000, n_features=2, cluster_std=0.1)

# Number of worker processes for every multiprocessing Pool created below.
MAX_PROCESSOR = 8
# Per-process store for the shared-memory buffers and their shapes;
# filled in by ``init_worker`` when a pool worker starts.
var_dict = dict()

# Last list of neighbour-search results produced by ``fast_local_search``
# (kept at module level for inspection/debugging).
r_tot_c = 0


def f_swap(args):
    """Dispatch one pool task of the form ``(tag, *payload)``.

    ``"SWAP"`` runs ``SWAP_ORG(*payload)``, ``"SEARCH"`` runs
    ``NEIGHBOR_SEARCH(*payload)``; any other tag yields ``None``.
    """
    tag = args[0]
    if tag == "SWAP":
        return SWAP_ORG(*args[1:])
    elif tag == "SEARCH":
        return NEIGHBOR_SEARCH(*args[1:])
    return None


def SWAP_ORG(pair, init):
    """Pool worker: full weighted cost after one center swap.

    Args:
        pair: ``(slot, point)`` — ``init[slot]`` is replaced by data index ``point``.
        init: current center indices into the shared data array.

    Returns:
        ``(cost, slot, point)`` where ``cost`` is the sum over all points of
        weight times squared distance to the nearest center after the swap.
    """
    out_slot, candidate = pair[0], pair[1]
    # Views onto the shared buffers registered by init_worker.
    points = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    weights = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])

    trial = init.copy()
    trial[out_slot] = candidate
    centers = np.array(points[trial])

    nearest, _ = BallTree(centers, leaf_size=40).query(points, k=1)
    squared = nearest[:, 0] ** 2
    total = (squared * weights).sum()
    return total, out_slot, candidate


def fast_local_search(data, init, ybxl, r, t, k, W):
    """Parallel weighted local search over data-point centers.

    Centers are rows of ``data`` identified by index. One pass runs:

    1. ``r`` rounds of D^2-sampled single swaps: a point is drawn with
       probability proportional to its weighted squared distance to the
       current centers, every center slot is tried as the one to replace
       (evaluated in the pool via ``SWAP_ORG``), and the best swap is kept
       if it lowers the cost.
    2. Lloyd-style refinement: weighted cluster means are snapped to their
       nearest data points, repeated while the cost decreases.
    3. Neighbour swaps: each center is tried against its 10, 20, 30, 50 and
       100 nearest data points (evaluated in the pool via
       ``NEIGHBOR_SEARCH``); improving swaps are applied until a neighbourhood
       size yields none, then the next size is tried.

    Afterwards one random center is moved to a random non-center point
    ("mutation"), and the search stops when ``fail >= 0`` — i.e. currently
    always after the first pass.

    Args:
        data: ``(n, d)`` array of points.
        init: ``k`` initial center indices into ``data``; copied, never
            modified in place.
        ybxl: unused; kept for signature compatibility.
        r: number of sampled-swap rounds in phase 1.
        t: number of neighbours queried for the initial cost (only the
            nearest is used).
        k: number of centers.
        W: ``(n,)`` per-point weights.

    Returns:
        The best weighted cost found.
    """
    global r_tot_c

    print("-----------------------------------Start Original Local Search------------------------------------")
    n = data.shape[0]

    # Copy data and weights into shared memory so pool workers can read them
    # through init_worker instead of receiving them with every task.
    data_shape = (data.shape[0], data.shape[1])
    data_g = RawArray('d', data_shape[0] * data_shape[1])
    w_shape = (W.shape[0])
    w_g = RawArray('d', w_shape)
    np.copyto(np.frombuffer(data_g).reshape(data_shape), data)
    np.copyto(np.frombuffer(w_g).reshape(w_shape), W)

    # Private copy: the caller's list/array must not be changed by swaps.
    init = np.array(init, dtype=int)
    cost_glob = math.inf
    fail = 0

    # The context manager terminates the workers on return; previously the
    # pool was never closed and its processes leaked on every call.
    with Pool(processes=MAX_PROCESSOR, initializer=init_worker,
              initargs=(data_g, data_shape, w_g, w_shape)) as pool1:
        dist, _ = BallTree(data[init], leaf_size=40).query(data, k=t)
        weighted = dist[:, 0] ** 2 * W
        cost_now = weighted.sum()
        prob = weighted / weighted.sum()

        # Tree over all points; the data never change, so build it once.
        data_tree = KDTree(data, leaf_size=40)

        while True:
            # Phase 1: D^2-sampled single swaps, scored in parallel.
            for _ in range(r):
                nextpoint = np.random.choice(n, 1, replace=False, p=prob)[0]
                tasks = [("SWAP", [j, nextpoint], init.copy()) for j in range(k)]
                # min() keeps the first minimum, like the original strict-< scan.
                best_cost, id_out, id_in = min(pool1.imap(f_swap, tasks),
                                               key=lambda res: res[0])
                if best_cost < cost_now:
                    init[id_out] = id_in
                    dist, _ = BallTree(data[init], leaf_size=40).query(data, k=1)
                    weighted = dist[:, 0] ** 2 * W
                    cost_now = weighted.sum()
                    prob = weighted / weighted.sum()

            # Phase 2: Lloyd refinement with centers snapped to data points.
            centers = data[init]
            _, labels = KDTree(centers, leaf_size=40).query(data, k=1)
            labels = labels[:, 0]
            while True:
                means = centers.copy()
                for c_id in range(k):
                    members = np.flatnonzero(labels == c_id)
                    w_sum = W[members].sum()
                    if w_sum == 0:
                        # Empty (or zero-weight) cluster: keep the previous
                        # center instead of producing a NaN mean via 0/0.
                        continue
                    means[c_id] = (data[members] * W[members][:, None]).sum(axis=0) / w_sum
                _, snapped = data_tree.query(means, k=1)
                snapped = snapped[:, 0]
                candidate = data[snapped]
                dist_l, labels = KDTree(candidate, leaf_size=40).query(data, k=1)
                labels = labels[:, 0]
                cost_l = (dist_l[:, 0] ** 2 * W).sum()
                if cost_l < cost_now:
                    cost_now = cost_l
                    init = snapped.copy()
                    centers = candidate
                else:
                    break

            # Phase 3: swap each center with one of its nearest data points.
            centers = data[init]
            k_list = [10, 20, 30, 50, 100]
            k_id = 0
            while k_id < len(k_list):
                # Cap at n: asking for more neighbours than points would fail.
                _, neigh = data_tree.query(centers, k=min(k_list[k_id], n))
                tasks = [("SEARCH", [c_id, cand], init.copy())
                         for c_id in range(k) for cand in neigh[c_id]]
                # NOTE(review): each SEARCH result carries a full length-n
                # distance array, so this list can be large for big n.
                r_tot = list(pool1.imap(f_swap, tasks))
                r_tot_c = r_tot.copy()
                best_cost, id_out, id_in = min(r_tot, key=lambda res: res[0])[:3]
                if best_cost < cost_now:
                    cost_now = best_cost
                    init[id_out] = id_in
                    centers = data[init]
                else:
                    k_id += 1

            if cost_now < cost_glob:
                cost_glob = cost_now
            else:
                fail += 1

            # Mutation: move a random center to a random point that is not
            # already a center. The original discarded the set difference, so
            # it could pick an existing center and create a duplicate.
            candidates = np.setdiff1d(np.arange(n), init)
            if candidates.size > 0:
                rd1 = random.sample(range(0, len(candidates)), 1)[0]
                rd2 = random.sample(range(0, len(init)), 1)[0]
                init[rd2] = candidates[rd1]
            dist, _ = BallTree(data[init], leaf_size=40).query(data, k=1)
            weighted = dist[:, 0] ** 2 * W
            cost_now = weighted.sum()
            prob = weighted / weighted.sum()

            # fail is never negative, so the search stops after one pass.
            if fail >= 0:
                return cost_glob


def generate_candidate_set(center_list):
    """Return every non-empty subset of ``center_list`` as a list.

    Subsets are ordered by size (1 first), and within one size in the
    order produced by ``itertools.combinations``.
    """
    sizes = range(1, len(center_list) + 1)
    return [list(subset) for size in sizes for subset in combinations(center_list, size)]


def fast_local_search1(data, init, k, ybxl, r, t, c, alpha, epsilon, W):
    """Serial sampled multi-swap local search over data-point centers.

    One pass runs:

    1. ``r`` rounds where each point becomes a candidate with probability
       ``min(1, 3 * p)`` (``p`` = its share of the weighted cost). Rounds with
       1-3 candidates evaluate every non-empty subset of the candidates in
       every assignment to distinct center slots, scoring each trial on
       ``groups`` uniform samples. The best sampled trial is then checked on
       the full data and kept if it lowers the cost.
    2. Lloyd-style refinement: weighted cluster means snapped to data points.
    3. Serial neighbour swaps against each center's 10, 20, 30, 50 nearest
       data points.

    A swap-count score then drives a mutation of one center, and the search
    stops when ``fail >= 0`` (currently always after one pass).

    Args:
        data: ``(n, d)`` array of points.
        init: ``k`` initial center indices into ``data``; copied, never
            modified in place.
        k: number of centers.
        ybxl, c, alpha: unused; kept for signature compatibility.
        r: number of sampled-swap rounds.
        t: number of neighbours queried for the initial cost (only the
            nearest is used).
        epsilon: accuracy parameter for the sample size.
        W: ``(n,)`` per-point weights.

    Returns:
        The best weighted cost found (the original returned ``None``).
    """
    print("-----------------------------------Start Fast Local Search------------------------------------")
    groups = 5
    n = data.shape[0]
    sample_size = math.ceil(3 * k * math.log2(k / epsilon) / epsilon)
    # random.sample cannot draw more than n items; cap as fast_local_search2 does.
    sample_size = min(n, sample_size)
    cost_glob = math.inf
    fail = 0
    swap_count = np.zeros(n)
    # Private copy so the caller's centers are never mutated.
    init = np.array(init, dtype=int)

    while True:
        dist, _ = BallTree(data[init], leaf_size=40).query(data, k=t)
        dist = dist[:, 0] ** 2 * W
        prob = dist / dist.sum()
        cost_now = dist.sum()

        print("Check Size", "Sample Size", sample_size, "Group Size", groups)
        tswap = time.time()
        for i in range(r):
            # A point is picked when 3*p >= U(0, 1), i.e. with prob min(1, 3p).
            picked = np.flatnonzero(prob * 3 - np.random.rand(n) >= 0)
            if len(picked) < 1 or len(picked) > 3:
                continue
            best_sample_cost = math.inf
            best_init = None
            removed = []
            for points in generate_candidate_set(list(picked)):
                # combinations() yields slot tuples in the same lexicographic
                # order as the former nested j < j1 < j2 loops.
                for slots in combinations(range(k), len(points)):
                    trial = init.copy()
                    for slot, point in zip(slots, points):
                        trial[slot] = point
                    tree = BallTree(data[trial], leaf_size=40)
                    for _group in range(groups):
                        id_sample = np.array(random.sample(range(0, n), sample_size), dtype=int)
                        d1, _ = tree.query(data[id_sample], k=1)
                        cost1 = (d1[:, 0] ** 2 * W[id_sample]).sum()
                        if cost1 < best_sample_cost:
                            best_sample_cost = cost1
                            removed = [init[slot] for slot in slots]
                            best_init = trial.copy()
            if best_init is None:
                continue
            # Confirm the best sampled trial on the full data set.
            d_full, _ = BallTree(data[best_init], leaf_size=50).query(data, k=1)
            d_full = d_full[:, 0] ** 2 * W
            cost_next = d_full.sum()
            if cost_next < cost_now:
                for center in removed:
                    swap_count[center] -= 1
                cost_now = cost_next
                prob = d_full / d_full.sum()
                init = best_init.copy()
                print("Round", i, "Has A Swap", cost_now)

        # Lloyd Type Search
        tswap1 = time.time()
        print("Local Search Takes Time", tswap1 - tswap)
        print("Enter the Lloyd")
        tlloyd = time.time()
        # Tree over all points; the data never change, so build it once and
        # reuse it for both the Lloyd snapping and the neighbour search.
        data_tree = BallTree(data, leaf_size=40)
        centers = data[init]
        _, labels = BallTree(centers, leaf_size=40).query(data, k=1)
        labels = labels[:, 0]
        while True:
            means = centers.copy()
            for c_id in range(k):
                members = np.flatnonzero(labels == c_id)
                w_sum = W[members].sum()
                if w_sum == 0:
                    # Empty (or zero-weight) cluster: keep the previous center
                    # instead of producing a NaN mean via 0/0.
                    continue
                means[c_id] = (data[members] * W[members][:, None]).sum(axis=0) / w_sum
            _, snapped = data_tree.query(means, k=1)
            snapped = snapped[:, 0]
            candidate = data[snapped]
            dist_l, labels = BallTree(candidate, leaf_size=40).query(data, k=1)
            labels = labels[:, 0]
            cost_l = (dist_l[:, 0] ** 2 * W).sum()
            if cost_l < cost_now:
                cost_now = cost_l
                init = snapped.copy()
                centers = candidate.copy()
                print("Lloyd Has A Swap", cost_now)
            else:
                print("No Better", cost_l)
                break
        tlloyd1 = time.time()
        print("Lloyd Takes Time", tlloyd1 - tlloyd)

        tneighbor = time.time()
        print("Enter the Nearest Neighbot Search")
        centers = data[init]
        k_list = [10, 20, 30, 50]
        k_id = 0
        ct = 0
        # Distances of the best neighbour trial of the last round; used by the
        # random mutation below.
        dist_f = 0
        while k_id < len(k_list):
            # Cap at n: asking for more neighbours than points would fail.
            _, neigh = data_tree.query(centers, k=min(k_list[k_id], n))
            best_cost = 10000000000000000000000000000000005
            best_centers = None
            best_init = None
            id_out = -1
            for i2 in range(k):
                for cand in neigh[i2]:
                    trial_centers = centers.copy()
                    trial_centers[i2] = data[cand]
                    d3, _ = BallTree(trial_centers, leaf_size=40).query(data, k=1)
                    d3 = d3[:, 0] ** 2 * W
                    cost3 = d3.sum()
                    if cost3 < best_cost:
                        best_cost = cost3
                        id_out = init[i2]
                        best_centers = trial_centers.copy()
                        best_init = init.copy()
                        best_init[i2] = cand
                        dist_f = d3.copy()
            if best_cost < cost_now:
                swap_count[id_out] -= 1
                cost_now = best_cost
                centers = best_centers
                init = best_init
                print("Round", ct, "Has A Swap", cost_now)
                ct += 1
            else:
                k_id += 1

        tneighbor1 = time.time()
        print("Neighbor Search Takes Time", tneighbor1 - tneighbor)

        # NOTE(review): current centers are credited twice (fancy-indexed add
        # plus the loop). Kept as-is because it shapes the mutation heuristic;
        # confirm this is intended.
        swap_count[init] += 1
        for center in init:
            swap_count[center] += 1
        if cost_now < cost_glob:
            cost_glob = cost_now
        else:
            fail += 1

        # Mutation: move a random center to the top-scoring non-center point,
        # or, if every top-k point is already a center, to a point drawn from
        # the last neighbour-trial distances.
        score_id_set = set(np.argsort(-swap_count)[0:k]) - set(init)
        if len(score_id_set) < 1:
            print("Random Mutation")
            prob_mu = dist_f / dist_f.sum()
            mu_id = random.sample(range(0, k), 1)[0]
            swap_id = np.random.choice(n, 1, replace=False, p=prob_mu)[0]
            init[mu_id] = swap_id
        else:
            print("Score Mutation")
            mu_id = random.sample(range(0, k), 1)[0]
            init[mu_id] = list(score_id_set)[0]
        if fail >= 0:
            print("---------------------Final Clustering Cost---------------------", cost_glob)
            break
    return cost_glob


def init_worker(data, data_shape, w, w_shape):
    """Pool initializer: record the shared buffers and shapes in ``var_dict``.

    Workers later rebuild numpy views from these entries with
    ``np.frombuffer(...).reshape(...)``.
    """
    var_dict.update(data=data, data_shape=data_shape, w=w, w_shape=w_shape)


def swap_t(args):
    """Dispatch one pool task of the form ``(tag, *payload)``.

    ``"LOCAL_SEARCH"`` -> ``SWAP_SEARCH1``, ``"LOCAL_SEARCH1"`` ->
    ``SWAP_SEARCH2``, ``"NEIGHBOR_SEARCH"`` -> ``NEIGHBOR_SEARCH``; any other
    tag yields ``None``.
    """
    tag = args[0]
    if tag == "LOCAL_SEARCH":
        return SWAP_SEARCH1(*args[1:])
    elif tag == "LOCAL_SEARCH1":
        return SWAP_SEARCH2(*args[1:])
    elif tag == "NEIGHBOR_SEARCH":
        return NEIGHBOR_SEARCH(*args[1:])
    return None


def SWAP_SEARCH1(groups, L, init):
    """Pool worker: best sampled cost of one multi-swap over given samples.

    Args:
        groups: sequence of ``(points, weights)`` pairs, one per sample.
        L: ``(slots, points)``; ``init[slots[i]]`` is replaced by ``points[i]``.
        init: current center indices into the shared data array.

    Returns:
        ``(best_cost, [slots], [points])`` where ``best_cost`` is the lowest
        weighted sampled cost over ``groups``. The lists are empty when no
        sample beats the sentinel (e.g. ``groups`` is empty).
    """
    best = 100000000000000000000000000000000005
    out_ids, in_ids = [], []
    shared = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    slots, points = L[0], L[1]
    samples = [g[0] for g in groups]
    sample_weights = [g[1] for g in groups]

    trial = init.copy()
    for idx in range(len(slots)):
        trial[slots[idx]] = points[idx]
    tree = BallTree(shared[trial], leaf_size=40)

    for sample, wts in zip(samples, sample_weights):
        nearest, _ = tree.query(sample, k=1)
        cost = (nearest[:, 0] ** 2 * wts).sum()
        if cost < best:
            best = cost
            out_ids = [slots]
            in_ids = [points]
    return best, out_ids, in_ids


def SWAP_SEARCH2(groups, sample_size, L, init):
    """Pool worker: best slot assignment for candidate points, by sampling.

    ``L[0]`` holds candidate data indices. Every placement of them into
    distinct center slots of ``init`` is tried, with slot tuples in
    lexicographic order. Each placement is scored ``groups`` times on fresh
    uniform samples of ``sample_size`` points from the shared data.

    Fixes over the original:
    - It used a name ``k`` that this function never defines.
    - The 3-point case iterated ``range(q1 + 1)`` instead of
      ``range(q1 + 1, k)``, so slots could repeat.
    - The 3-point case scored only the last tree built.

    Args:
        groups: number of random samples per placement.
        sample_size: points per sample; must not exceed the data size.
        L: sequence whose first item is the candidate indices.
        init: current center indices into the shared data array.

    Returns:
        ``(best_cost, slots, points)``: the lowest sampled weighted cost, the
        replaced slot positions, and the inserted data indices. Both lists
        are empty if no placement beat the sentinel.
    """
    Min = 100000000000000000000000000000000005
    id_out = []
    id_in = []
    data1 = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    w = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])
    nextpoint = list(L[0])
    # The number of centers is simply the length of the current solution.
    k = len(init)
    for slots in combinations(range(k), len(nextpoint)):
        init_new = init.copy()
        for slot, point in zip(slots, nextpoint):
            init_new[slot] = point
        Tree1 = BallTree(data1[init_new], leaf_size=40)
        for _group in range(groups):
            id_sample = np.array(random.sample(range(0, data1.shape[0]), sample_size), dtype=int)
            dist1, _ = Tree1.query(data1[id_sample], k=1)
            cost1 = (dist1[:, 0] ** 2 * w[id_sample]).sum()
            if cost1 < Min:
                Min = cost1
                id_out = list(slots)
                id_in = list(nextpoint)
    return Min, id_out, id_in


def NEIGHBOR_SEARCH(L, init):
    """Pool worker: full weighted cost after replacing slot ``L[0]`` by point ``L[1]``.

    Returns:
        ``(cost, slot, point, sq_dist)`` where ``sq_dist`` is each point's
        squared distance to its nearest center after the swap and ``cost`` is
        ``sum(sq_dist * w)``.
    """
    slot, point = L[0], L[1]
    shared_data = np.frombuffer(var_dict['data']).reshape(var_dict['data_shape'])
    shared_w = np.frombuffer(var_dict['w']).reshape(var_dict['w_shape'])

    trial = init.copy()
    trial[slot] = point
    tree = BallTree(shared_data[trial], leaf_size=40)
    nearest, _ = tree.query(shared_data, k=1)
    sq_dist = nearest[:, 0] ** 2
    return (sq_dist * shared_w).sum(), slot, point, sq_dist


def kmeans_plus(data, k):
    """Choose ``k`` center indices by k-means++ (D^2) seeding.

    The first center is drawn uniformly with ``random``. Each further center
    is drawn with ``np.random.choice``, with probability proportional to its
    squared distance to the nearest center chosen so far.

    Args:
        data: ``(n, d)`` array of points.
        k: number of centers to choose.

    Returns:
        List of ``k`` row indices into ``data``.
    """
    n = data.shape[0]
    init = []
    prob = None
    for i in range(k):
        if i == 0:
            init.append(random.sample(range(0, n), 1)[0])
        else:
            init.append(np.random.choice(n, p=prob, size=1, replace=False)[0])
        if i == k - 1:
            # The last pick needs no new distribution; skip the tree build
            # and full-data query.
            break
        tree = BallTree(data[np.array(init, dtype=int)], leaf_size=40)
        # The query does not modify its input, so no defensive copy is needed.
        dist, _ = tree.query(data, k=1)
        dist = dist[:, 0] ** 2
        prob = dist / dist.sum()
    return init


def Lloyd(data, k, W):
    """Baseline: weighted sklearn k-means, centers snapped to data points.

    Fits ``KMeans`` (k-means++ seeding, 10 restarts, at most 300 iterations)
    with ``W`` as ``sample_weight``. Each fitted center is replaced by its
    nearest data point, and the weighted cost of those centers is printed
    and returned.

    Args:
        data: ``(n, d)`` array of points.
        k: number of clusters.
        W: ``(n,)`` per-point weights.

    Returns:
        ``sum(W * squared distance to the nearest snapped center)``.
    """
    km = KMeans(n_clusters=k, init="k-means++", n_init=10, max_iter=300)
    # Fit with the same weights used to score the result. An unweighted fit
    # optimises a different objective than the weighted cost reported below.
    km.fit(data, sample_weight=W)
    _, ind = BallTree(data, leaf_size=40).query(km.cluster_centers_, k=1)
    center_new = data[ind[:, 0]]
    dist, _ = BallTree(center_new, leaf_size=40).query(data, k=1)
    cost = (dist[:, 0] ** 2 * W).sum()
    print("------------------------Lloyd--------------------------", cost)
    return cost


def _fls2_nearest_sq_dist(points, centers):
    """Return (squared distance to nearest center, index of that center) per row of ``points``."""
    sq = pairwise_distances(points, centers, metric="euclidean") ** 2
    return np.min(sq, axis=1), np.argmin(sq, axis=1)


def _fls2_build_pairs(next_points, single, multi, triple):
    """Pair each candidate set with every same-sized group of center positions to swap out."""
    pair = []
    for cand in next_points:
        if len(cand) == 1:
            pair.extend([[pos], cand] for pos in single)
        elif len(cand) == 2:
            pair.extend([group, cand] for group in multi)
        else:
            pair.extend([group, cand] for group in triple)
    return pair


def _fls2_swap_rounds(pool, data, W, init, k, rounds, groups, sample_size,
                      single, multi, triple, inf):
    """Sampled multi-swap local search; return (center indices, weighted cost)."""
    n = data.shape[0]
    dist, _ = _fls2_nearest_sq_dist(data, data[init])
    dist = dist * W
    prob = dist / dist.sum()
    cost_now = dist.sum()
    print("Check Init Cost", cost_now)
    for i in range(rounds):
        # Each point becomes a candidate independently with probability
        # 1.5 * (its share of the current cost).
        id_range = np.argwhere(prob * 1.5 - np.random.rand(n) >= 0)[:, 0]
        # Only rounds that draw one or two candidates are evaluated.
        if len(id_range) < 1 or len(id_range) > 2:
            continue
        next_points = generate_candidate_set(list(id_range))

        # `groups` independent uniform subsamples, handed to the workers.
        sample_groups = []
        for _ in range(groups):
            id_sample = np.array(random.sample(range(0, n), sample_size), dtype=int)
            sample_groups.append([data[id_sample], W[id_sample]])

        pair = _fls2_build_pairs(next_points, single, multi, triple)
        tasks = [("LOCAL_SEARCH", sample_groups, p, init.copy()) for p in pair]
        # Each worker result is presumably (estimated cost, positions out,
        # point ids in) -- keep the lowest estimate.
        Min, id_out, id_in = inf, [], []
        for res in pool.imap(swap_t, tasks):
            if res[0] < Min:
                Min, id_out, id_in = res[0], res[1], res[2]

        init_f = init.copy()
        for pos, new_id in zip(id_out, id_in):
            init_f[pos] = new_id

        dist1, _ = _fls2_nearest_sq_dist(data, data[init_f])
        dist1 = dist1 * W
        cost_next = dist1.sum()
        # Accept only if the exact full-data cost drops by at least 1/(100k).
        if cost_next < cost_now * (1 - 1 / (100 * k)):
            cost_now = cost_next
            prob = dist1 / dist1.sum()
            init = init_f
            print("Round", i, "Has A Swap", cost_now)
    return init, cost_now


def _fls2_lloyd_refine(data, W, init, k, cost_now):
    """Discrete Lloyd refinement (means snapped to data points); return (centers, cost)."""
    centers = data[init]
    _, labels = _fls2_nearest_sq_dist(data, centers)
    while True:
        # Weighted cluster means, computed in float: data may be an integer
        # dtype (e.g. uint8), which would truncate the means.
        means = centers.astype(float)
        for c in range(k):
            members = labels == c
            w_c = W[members]
            total = w_c.sum()
            # Empty (or zero-weight) cluster: keep the previous center rather
            # than producing a 0/0 NaN center.
            if total > 0:
                means[c] = (data[members] * w_c[:, None]).sum(axis=0) / total
        # Snap each mean to its nearest data point.
        _, nearest = _fls2_nearest_sq_dist(means, data)
        candidate = data[nearest]
        d_new, labels = _fls2_nearest_sq_dist(data, candidate)
        cost_new = (d_new * W).sum()  # weighted, consistent with cost_now
        if cost_new < cost_now:
            cost_now = cost_new
            init = nearest
            centers = candidate
            print("Lloyd Has A Swap", cost_now)
        else:
            print("No Better", cost_new)
            break
    return init, cost_now


def fast_local_search2(data, init, k, ybxl, r, t, c, alpha, epsilon, W, single, multi, triple, eta, rounds):
    """Sampled multi-swap local search for weighted k-means, followed by a Lloyd pass.

    Phase 1 runs ``rounds`` rounds of D^2-biased candidate sampling. Swap costs
    are estimated in a worker pool on 5 random subsamples, and the best swap is
    accepted when it lowers the exact cost by at least a 1/(100k) fraction.
    Phase 2 repeats weighted Lloyd steps, snapping each mean to its nearest
    data point, while the cost keeps decreasing.

    Args:
        data: array of shape (n_samples, n_features); copied to shared float64 memory.
        init: k center row indices into ``data``.
        k: number of centers.
        ybxl, r, t, c, alpha: unused; kept for call-site compatibility.
        epsilon, eta: set the subsample size
            ``min(n, ceil(3 * k * log2(k / eta) / epsilon))``.
        W: per-point weights, length n_samples.
        single, multi, triple: center-position groups (sizes 1/2/3) that may
            be swapped out for candidate sets of the same size.
        rounds: number of sampling rounds in phase 1.

    Returns:
        Final weighted sum of squared distances to the nearest center.
    """
    data_shape = (data.shape[0], data.shape[1])
    data_g = RawArray('d', data_shape[0] * data_shape[1])
    np.copyto(np.frombuffer(data_g).reshape(data_shape), data)
    w_shape = (W.shape[0])
    w_g = RawArray('d', w_shape)
    np.copyto(np.frombuffer(w_g).reshape(w_shape), W)

    inf = 1000000000000000000000000000000000000000005
    groups = 5
    sample_size = min(data.shape[0], math.ceil(3 * k * math.log2(k / eta) / epsilon))

    # The context manager terminates the workers once the swap phase is done;
    # previously the pool was never closed, leaking processes on every call.
    with Pool(processes=MAX_PROCESSOR, initializer=init_worker,
              initargs=(data_g, data_shape, w_g, w_shape)) as pool1:
        init, cost_now = _fls2_swap_rounds(pool1, data, W, init, k, rounds, groups,
                                           sample_size, single, multi, triple, inf)

    init, cost_now = _fls2_lloyd_refine(data, W, init, k, cost_now)
    print("---------------------Final Clustering Cost---------------------", cost_now)
    return cost_now


if __name__ == '__main__':
    total_list = []
    total_list1 = []
    # name_p = ['rds']
    # name_p = ['iris','seeds','glass']
    # name_p = ['UK_L','HF_L', 'Abs_FL']
    # name_p = ['rds','KEGG_FL','urbanGB_L_10','rng_agr','urbanGB_L','spnet3D','syn_1E7_2_3','USC1990']
    # name_p = ['iris','seeds','glass','HCV_L','BM_FL','UK_L','HF_L', 'Abs_FL','AC_FL','GT_FL','hemi','HF_L','HTRU2_L','pr2392','Who_FL','TR_FL','SGC_FL','AC_FL']
    name_p = ['SIFT']
    for d_i in range(len(name_p)):
        for d_j in np.array([3, 5, 10]):
            cost_list = []
            cost_list1 = []
            first_name = str(name_p[d_i]) + "_k=" + str(d_j)
            print(first_name)
            cost_list.append(first_name)
            cost_list1.append(first_name)
            k = d_j
            print("k:", k)
            dataset_cost10_list = []
            dataset_cost10_list1 = []
            single = [i for i in range(0, k)]
            multi = [[i, j] for i in range(0, k) for j in range(i + 1, k)]
            triple = [[i, j, j1] for i in range(0, k) for j in range(i + 1, k) for j1 in range(j + 1, k)]
            data_path = "data/" + str(name_p[d_i])
            if (name_p[d_i] == "USC1990"):
                data_path = data_path + ".txt"
                data = np.loadtxt(data_path, delimiter=',')
            elif (name_p[d_i] == "SUSY"):
                data_path = data_path + ".csv"
                data = np.loadtxt(data_path, usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                                  delimiter=',')
            elif (name_p[d_i] == "HIGGS"):
                data_path = data_path + ".csv"
                data = np.loadtxt(data_path, delimiter=',')
                data = data[:, 1:data.shape[1]]
            elif (name_p[d_i] == "SIFT"):
                x = np.memmap("learn.bvecs", dtype='uint8', mode='r')
                d = x[:4].view('int32')[0]
                data = x.reshape(-1, d + 4)[:, 4:]
                data = np.array(data)
            else:
                data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')



            # data_org = data.copy()
            # W_org = np.ones(data_org.shape[0])
            # data, W = np.unique(data, axis=0, return_counts=True)
            W = np.ones(data.shape[0])
            tt = 0
            tt1 = 0

            mmiter = 10
            for d_k in range(mmiter):
                # data, _ = make_blobs(n_samples=1000000,n_features=20,cluster_std=1.5)

                # data = np.loadtxt('yeast.txt',usecols = (1,2,3,4,5,6,7,8))
                # data, _ = make_blobs(n_samples=1000000,n_features=40,centers=10)
                # data = np.loadtxt('SUSY.csv',usecols=(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18), delimiter=",")

                # initialization
                # init_id = random.sample(range(0,data.shape[0]),k)
                # init_id = kmeans_plus(data.copy(), k)
                # init_id = np.array(init_id, dtype=int)

                init_id = random.sample(range(0, data.shape[0]), k)
                init_id = np.array(init_id, dtype=int)

                # Lloyd(data, k, W)

                t0 = time.time()
                # data_cost = fast_local_search(data.copy(), init_id.copy(), 1/100, 400, 2, k, W.copy())
                # dataset_cost10_list1.append(data_cost_1)

                # t1 = time.time()
                # fast_local_search1(data.copy(), init_id.copy(), k, 1/100, 400, 2, 10, 2, 0.5, W.copy())

                # fast_local_search2(data.copy(), init_id.copy(), k, 1/100, 1000, 2, 10, 2, 0.5)
                # t3 = time.time()

                data_cost = fast_local_search2(data, init_id.copy(), k, 1 / 100, 10, 2, 10, 2, 0.5, W.copy(),
                                               single.copy(), multi.copy(), triple.copy(), 0.5, 400)

                # data_cost = Lloyd(data_org.copy(), k, W_org)
                t2 = time.time()
                # print(data_cost)
                dataset_cost10_list.append(data_cost)
                t3 = time.time()

                tt1 += t2 - t0
                tt += t2 - t0
                # print(t2 - t1)
                # print(t3 - t2)
            # print(dataset_cost10_list)
            # print("ORG",tt1)
            print("NEW", tt)
            tt = tt / mmiter
            tt1 = tt1 / mmiter
            dataset_cost10_list = np.array(dataset_cost10_list)
            min_cost = np.min(dataset_cost10_list)
            max_cost = np.max(dataset_cost10_list)
            mean_cost = np.mean(dataset_cost10_list)
            std_cost = np.std(dataset_cost10_list)
            cost_list.append(min_cost)
            cost_list.append(max_cost)
            cost_list.append(mean_cost)
            cost_list.append(std_cost)
            cost_list.append(tt)
            # print(cost_list)
            total_list.append(cost_list)
            # print("Dataset",name_p[d_i],total_list)

            # dataset_cost10_list1 = np.array(dataset_cost10_list1)
            # min_cost1 = np.min(dataset_cost10_list1)
            # max_cost1 = np.max(dataset_cost10_list1)
            # mean_cost1 = np.mean(dataset_cost10_list1)
            # std_cost1 = np.std(dataset_cost10_list1)
            # cost_list1.append(min_cost1)
            # cost_list1.append(max_cost1)
            # cost_list1.append(mean_cost1)
            # cost_list1.append(std_cost1)
            # cost_list1.append(tt1)
            # total_list1.append(cost_list1)

    print(total_list)
    # print(total_list1)
    #

    # df = pd.DataFrame(total_list, columns=['dataset', 'min_cost', 'max_cost', 'mean_cost', 'std_cost','time'])
    # df1 = pd.DataFrame(total_list1, columns=['dataset', 'min_cost', 'max_cost', 'mean_cost', 'std_cost','time'])
    #
    # df.to_excel("MLSORG.xlsx", index=False)
    # df1.to_excel("ls_orgl.xlsx", index=False)