import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand


class LS(object):

    def __init__(self, n_clusters=None, rounds=400, trans=10, batch=100, total_batch=10, minibatchround=40):
        """Record the hyper-parameters used by the search methods.

        Parameters
        ----------
        n_clusters : int or None
            Number of centers; sizes the per-center statistics in
            ``LS_Bandit`` and sets the acceptance threshold in ``Fast_LS``.
        rounds : int
            Number of local-search rounds run by ``LS_Bandit`` / ``Fast_LS``.
        trans : int
            Number of proposal steps taken when ``LS_Bandit`` draws a swap
            candidate.
        batch : int
            Number of rows sampled per evaluation batch in ``LS_Bandit``.
        total_batch : int
            ``batch * total_batch`` is the per-round sample budget in
            ``LS_Bandit``.
        minibatchround : int
            Round count handed to ``minibatch_kmeans`` where that refinement
            is invoked.
        """
        # Problem size.
        self.n_clusters_ = n_clusters
        # Outer-loop and candidate-sampling lengths.
        self.rounds_, self.trans_ = rounds, trans
        # Sampling budget of the bandit comparison.
        self.batch_, self.total_batch_ = batch, total_batch
        # Mini-batch refinement length.
        self.minibatchround_ = minibatchround

    def CheckCost(self, X, centers):
        """Return the clustering cost of ``centers`` on ``X``.

        The cost is the sum, over the rows of ``X``, of the squared distance
        to the nearest row of ``centers`` as reported by a ``BallTree`` built
        on the centers (its default metric is used).
        """
        nearest_dist, _nearest_idx = BallTree(centers, leaf_size=40).query(X, k=1)
        squared = nearest_dist[:, 0] ** 2
        return squared.sum()

    def minibatch_kmeans(self, data, centers, batch, rounds, error):
        """Refine ``centers`` with mini-batch Lloyd steps on random samples.

        Every round draws ``batch`` rows of ``data`` without replacement,
        assigns them to their nearest center and proposes new centers: the
        mean of the sampled points assigned to each center (a center that
        received no sampled point keeps its position).  The proposal replaces
        the current centers only if it lowers the cost measured on that same
        sample.  From the fourth round on, while the last accepted step still
        improved the cost by more than a factor ``1 + error``, the batch grows
        by that factor, capped at ``10 * k * log2(k)`` rows.

        The search stops as soon as the old cost is within a factor
        ``1 + error`` of the proposal's cost, or after ``rounds`` rounds.

        Parameters
        ----------
        data : ndarray, indexed as ``data[list_of_row_ids]``
        centers : ndarray with one row per center
            Starting centers; this array is not written to.
        batch : int
            Initial sample size.  It is clamped to ``[1, data.shape[0]]``
            because ``random.sample`` cannot draw more rows than exist.
        rounds : int
            Maximum number of rounds.
        error : float
            Relative tolerance used for both the stopping test and the batch
            growth.

        Returns
        -------
        ndarray
            The refined centers (the input array itself if no step was
            accepted).

        Notes
        -----
        * This method used to be defined twice in the class; the first
          definition was silently replaced by the second at class-creation
          time and has been removed.
        * On convergence the cheaper of the two center sets is returned.  The
          earlier code returned the more expensive one (inverted comparison).
        """
        n_samples = data.shape[0]
        n_centers = centers.shape[0]

        # random.sample needs an integer count no larger than the population.
        batch = max(1, min(math.ceil(batch), n_samples))
        # Growth cap 10*k*log2(k); log2(1) == 0 would shrink the batch to
        # nothing, so never go below a single point.
        if n_centers > 1:
            batch_cap = max(1, math.ceil(10 * n_centers * math.log2(n_centers)))
        else:
            batch_cap = 1

        ratio = 10
        for j in range(rounds):
            if j > 2 and ratio > 1 + error:
                batch = min(math.ceil(batch * (1 + error)), batch_cap, n_samples)

            sample_ids = random.sample(range(n_samples), batch)
            points = data[sample_ids]

            nearest_dist, nearest_idx = BallTree(centers, leaf_size=40).query(points, k=1)
            nearest_idx = nearest_idx[:, 0]
            cost_old = (nearest_dist[:, 0] ** 2).sum()

            # Per-center sums and counts of the sampled points.  np.add.at
            # accumulates repeated indices one by one, like an explicit loop.
            sums = np.zeros((n_centers, data.shape[1]))
            np.add.at(sums, nearest_idx, points)
            counts = np.bincount(nearest_idx, minlength=n_centers)
            filled = counts > 0

            # Empty clusters keep their old position, the others move to the mean.
            candidate = np.array(centers, dtype=float)
            candidate[filled] = sums[filled] / counts[filled, None]

            cand_sq = pairwise_distances(points, candidate, metric="euclidean") ** 2
            cost_new = cand_sq.min(axis=1).sum()

            # Same test as cost_old / cost_new < 1 + error, written without a
            # division so that a zero cost does not produce inf/nan warnings.
            if cost_old < (1 + error) * cost_new:
                return candidate if cost_new < cost_old else centers

            if cost_new < cost_old:
                if j >= 2:
                    ratio = cost_old / cost_new if cost_new > 0 else float('inf')
                centers = candidate
        return centers

    def LS_Bandit(self, X, centers, mode):
        """Improve ``centers`` by sampled single-swap local search, then polish.

        Each of the ``self.rounds_`` rounds

        1. draws a candidate data point with a short chain of
           ``self.trans_`` proposals (``_draw_swap_candidate``);
        2. selects two centers that could be swapped out: the center closest
           to the candidate and one other center chosen uniformly at random;
        3. estimates on random batches of ``X`` how the cost would change for
           each of the two swaps and drops swaps whose confidence interval
           rules them out (``_race_swap_arms``);
        4. applies a surviving swap according to ``mode``.

        Afterwards the centers initialise a ``MiniBatchKMeans`` run
        (``max_iter=10``, ``batch_size=1024``) on ``X``.

        Parameters
        ----------
        X : ndarray with one data point per row
        centers : ndarray with ``self.n_clusters_`` rows
            Initial centers.  A private float copy is used, so the caller's
            array is left untouched.
        mode : str
            ``'v1'``: every surviving swap is evaluated exactly on all of
            ``X``; the cheapest one is applied only if it lowers the exact
            cost.
            ``'v2'``: a swap is applied only when exactly one swap survives;
            no exact cost is computed.
            Any other value performs no swaps (only the final polish runs).

        Returns
        -------
        ndarray
            ``cluster_centers_`` of the fitted ``MiniBatchKMeans``.

        Notes
        -----
        Fixes relative to the earlier implementation:

        * the final polish used the module-level globals ``k`` and ``data``
          instead of ``self.n_clusters_`` and ``X``;
        * ``mode='v1'`` read ``cost_now`` before it was ever assigned;
        * the interval bound could be unbound or left over from a previous
          round when the sampling loop did not run; such a round is now
          skipped;
        * the evaluation batch is clamped to the number of rows of ``X``;
        * a hard-disabled alternative statistics branch and code placed after
          an unconditional ``continue`` were removed.
        """
        n_samples = X.shape[0]
        n_clusters = self.n_clusters_
        centers = np.array(centers, dtype=float)

        # random.sample cannot draw more rows than X has.
        batch = min(self.batch_, n_samples)
        budget = self.batch_ * self.total_batch_

        # Only 'v1' compares swaps against the exact current cost.
        cost_now = self._exact_cost(X, centers) if mode == 'v1' else None

        for _ in range(self.rounds_):
            candidate, cand_sq = self._draw_swap_candidate(X, centers)
            if candidate is None:
                continue

            closest = int(np.argmin(cand_sq))
            others = [c for c in range(n_clusters) if c != closest]
            other = others[random.sample(range(len(others)), 1)[0]]

            survivors, bound = self._race_swap_arms(
                X, centers, candidate, [closest, other], batch, budget)
            # bound == 0: no swap has an upper confidence bound below zero,
            # i.e. none is confidently an improvement.
            if bound == 0 or not survivors:
                continue

            if mode == 'v1':
                best_cost, best_centers = float('inf'), None
                for arm in survivors:
                    trial = centers.copy()
                    trial[arm] = candidate
                    trial_cost = self._exact_cost(X, trial)
                    if trial_cost < best_cost:
                        best_cost, best_centers = trial_cost, trial
                if best_cost < cost_now:
                    cost_now, centers = best_cost, best_centers
            elif mode == 'v2':
                if len(survivors) == 1:
                    centers[survivors[0]] = candidate

        # Polish the result with a few mini-batch Lloyd iterations.
        km = MiniBatchKMeans(n_clusters=self.n_clusters_, init=centers, n_init=1, max_iter=10, batch_size=1024,
                             compute_labels=False)
        km.fit(X)
        return km.cluster_centers_

    @staticmethod
    def _sq_dists_to(rows, point):
        """Squared Euclidean distance from every row of ``rows`` to ``point``.

        Uses the expansion |r|^2 + |p|^2 - 2 r.p, so results can be off from
        the exact value by rounding (including tiny negative numbers).
        """
        return np.sum(rows ** 2, axis=1) + np.sum(point ** 2, axis=0) - 2 * np.sum(rows * point, axis=1)

    @staticmethod
    def _exact_cost(X, centers):
        """Sum over all rows of ``X`` of the squared distance to the nearest center."""
        sq = pairwise_distances(X, centers, metric="euclidean") ** 2
        return sq.min(axis=1).sum()

    def _draw_swap_candidate(self, X, centers):
        """Draw the data point proposed as a new center.

        Starts from a uniformly random row of ``X`` and runs ``self.trans_``
        proposal steps; a proposal replaces the current point when the ratio
        of their squared distances to the nearest center exceeds a random
        threshold.

        Returns ``(point, sq_dists_to_centers)``, or ``(None, None)`` when the
        starting row has distance exactly zero to some center.
        """
        n_samples = X.shape[0]
        point = X[random.sample(range(n_samples), 1)[0]]
        point_sq = self._sq_dists_to(centers, point)
        best = point_sq.min()
        if best == 0:
            return None, None

        for _ in range(self.trans_):
            proposal = X[random.sample(range(n_samples), 1)[0]]
            proposal_sq = self._sq_dists_to(centers, proposal)
            proposal_best = proposal_sq.min()

            # NOTE(review): the threshold is a standard-normal draw, kept as
            # in the earlier code; a Metropolis-style acceptance rule would
            # normally use a uniform draw on [0, 1) - confirm the intent.
            threshold = np.random.randn()

            # An accepted proposal can sit exactly on a center; avoid
            # dividing by zero on the next step.
            if best == 0:
                best = 0.01

            if proposal_best / best > threshold:
                best = proposal_best
                point_sq = proposal_sq
                point = proposal
        return point, point_sq

    def _race_swap_arms(self, X, centers, candidate, arms, batch, budget):
        """Compare candidate swaps on random batches and drop the losers.

        An "arm" is the index of a center that would be replaced by
        ``candidate``.  For every sampled point the change in squared distance
        to its nearest center caused by that swap is computed; a running mean
        and a confidence radius of that change are kept per arm.  After each
        batch, with ``bound = min(0, smallest upper bound)``, only arms whose
        lower bound is at most ``bound`` stay in the race.  Sampling stops
        when at most one arm is left or ``budget`` rows have been used.

        Returns ``(surviving_arms, bound)``.  ``bound`` is 0 when no batch was
        drawn or when no arm has an upper bound below zero.
        """
        n_samples = X.shape[0]
        n_clusters = self.n_clusters_
        mean = np.zeros(n_clusters)
        radius = np.zeros(n_clusters)
        # Running sum of squared deviations per arm.
        m2 = np.zeros(n_clusters)

        delta = 1 / n_samples
        log_term = math.log10(1 / delta)

        nbrs = NearestNeighbors(n_neighbors=2).fit(centers)

        alive = list(arms)
        bound = 0
        used = 0
        while used < budget and len(alive) > 1:
            pts = X[random.sample(range(n_samples), batch)]
            two_dist, two_idx = nbrs.kneighbors(pts)
            two_sq = two_dist ** 2
            d_first, d_second = two_sq[:, 0], two_sq[:, 1]
            owner = two_idx[:, 0]
            to_candidate = self._sq_dists_to(pts, candidate)

            total = used + batch
            for arm in alive:
                # Without `arm`, its points fall back to their second-nearest
                # center; with the candidate added, each point takes the
                # smaller of that and its distance to the candidate.
                without_arm = np.where(owner == arm, d_second, d_first)
                change = np.minimum(without_arm, to_candidate) - d_first

                if used == 0:
                    mean[arm] = change.sum() / batch
                    m2[arm] = (np.std(change) ** 2) * batch
                    spread = math.sqrt(m2[arm] / batch)
                    radius[arm] = spread * math.sqrt(2 * log_term / batch)
                else:
                    old_mean = mean[arm]
                    mean[arm] = (used * old_mean + change.sum()) / total
                    m2[arm] += ((change - old_mean) * (change - mean[arm])).sum()
                    # Guard against a tiny negative value from rounding.
                    spread = math.sqrt(max(m2[arm], 0.0) / total)
                    # NOTE(review): the first batch uses sqrt(2*log/n) while
                    # later batches use 0.25*sqrt(log/n); kept as found.
                    radius[arm] = spread * (0.25 * math.sqrt(log_term / total))
            used = total

            alive_idx = np.array(alive, dtype=int)
            lower = mean[alive_idx] - radius[alive_idx]
            upper = mean[alive_idx] + radius[alive_idx]
            bound = min(upper.min(), 0)
            alive = [arm for arm, low in zip(alive, lower) if low <= bound]
        return alive, bound

    def Fast_LS(self, X, centers):
        """Single-swap local search with cached nearest / second-nearest distances.

        The full matrix of squared distances between ``X`` and ``centers`` is
        kept in memory.  Each of the ``self.rounds_`` rounds

        1. samples one row of ``X`` with probability proportional to its
           squared distance to the nearest center;
        2. considers swapping it in for either the center closest to it or a
           uniformly random center, computing the resulting cost from the
           cached distances;
        3. performs the cheaper swap if it brings the cost below
           ``(1 - 1 / (100 * self.n_clusters_))`` times the current cost, and
           refreshes the cached distances.

        Parameters
        ----------
        X : ndarray with one data point per row
        centers : ndarray with one center per row
            Updated in place by accepted swaps and returned.

        Returns
        -------
        ndarray
            ``centers`` after the search.

        Notes
        -----
        * The sampled point was previously looked up in a module-level global
          ``data`` instead of ``X``.
        * A mini-batch refinement that followed the search could never run
          (the enclosing loop always returned first) and also referenced the
          global ``data``; it has been dropped, so the method still returns
          right after the search rounds.
        * The search stops early when the cost is zero, where the sampling
          distribution would be undefined.
        * With a single center there is no second-nearest center; an infinite
          distance is used in its place.
        """
        n_samples = X.shape[0]
        sq_dist = pairwise_distances(X, centers, metric="euclidean") ** 2
        owner, d_first, d_second = self._two_nearest(sq_dist)

        cost_now = d_first.sum()
        if not cost_now > 0:
            return centers
        sampling_prob = d_first / cost_now

        # A swap must push the cost below this fraction of the current cost.
        accept_factor = 1 - 1 / (100 * self.n_clusters_)

        for _ in range(self.rounds_):
            new_id = np.random.choice(n_samples, size=1, p=sampling_prob)[0]
            new_point = X[new_id].reshape(1, -1)
            to_new = pairwise_distances(X, new_point, metric="euclidean")[:, 0] ** 2

            # Per-point cost if the new point were simply added as a center.
            with_new = np.minimum(d_first, to_new)
            with_new_total = with_new.sum()

            closest = np.argmin(pairwise_distances(centers, new_point)[:, 0])
            random_pick = random.sample(range(len(centers)), 1)[0]

            best_cost, best_out = float('inf'), None
            for out in (closest, random_pick):
                # Points owned by `out` must fall back to their second-nearest
                # center or move to the new point, whichever is closer.
                hit = np.flatnonzero(owner == out)
                rerouted = np.minimum(d_second[hit], to_new[hit])
                cost_new = with_new_total - with_new[hit].sum() + rerouted.sum()
                if cost_new < best_cost:
                    best_cost, best_out = cost_new, out

            if best_cost < accept_factor * cost_now:
                cost_now = best_cost
                centers[best_out] = X[new_id]
                sq_dist[:, best_out] = to_new
                owner, d_first, d_second = self._two_nearest(sq_dist)

                total = d_first.sum()
                if not total > 0:
                    break
                sampling_prob = d_first / total
        return centers

    @staticmethod
    def _two_nearest(sq_dist):
        """Return ``(owner, d_first, d_second)`` for a points-by-centers matrix.

        ``owner[i]`` is the column of the smallest entry in row ``i``,
        ``d_first[i]`` that entry and ``d_second[i]`` the second smallest
        entry (``inf`` when the matrix has a single column).
        """
        order = np.argsort(sq_dist, axis=1)
        rows = np.arange(sq_dist.shape[0])
        owner = order[:, 0]
        d_first = sq_dist[rows, owner]
        if sq_dist.shape[1] > 1:
            d_second = sq_dist[rows, order[:, 1]]
        else:
            d_second = np.full(sq_dist.shape[0], np.inf)
        return owner, d_first, d_second


def Projection(X, centers):
    """Replace every center by its nearest row of ``X``.

    A ``BallTree`` is built on ``X`` and queried with the centers; the
    matching rows of ``X`` are returned, one per center.
    """
    _, nearest_idx = BallTree(X, leaf_size=40).query(centers, k=1)
    return X[nearest_idx[:, 0]]


def Projection1(X, centers):
    """Return, for every center, the row index of its nearest point in ``X``.

    A ``BallTree`` is built on ``X`` and queried with the centers.  Only the
    indices are returned; the rows themselves are not gathered (the earlier
    code gathered them and then discarded the result).
    """
    _, nearest_idx = BallTree(X, leaf_size=40).query(centers, k=1)
    return nearest_idx[:, 0]


def CheckCost(X, centers, batch):
    """Return the clustering cost of ``centers`` on ``X``, computed in chunks.

    The cost is the sum over the rows of ``X`` of the squared distance to the
    nearest center.  Rows are processed ``batch`` at a time so that the
    distance matrix held in memory never has more than ``batch`` rows.

    Parameters
    ----------
    X : ndarray with one data point per row
    centers : ndarray with one center per row
    batch : int
        Chunk size; must be at least 1.

    Returns
    -------
    The accumulated cost (``0`` when ``X`` has no rows).

    Raises
    ------
    ValueError
        If ``batch`` is smaller than 1.  Such a value used to fail only
        inside the distance computation on an empty slice, and a negative
        value first added a meaningless partial cost.
    """
    if batch < 1:
        raise ValueError("batch must be a positive integer")

    cost_tot = 0
    for start in range(0, X.shape[0], batch):
        # Slicing past the end is clipped, so the last chunk may be shorter.
        chunk_sq = pairwise_distances(X[start:start + batch], centers) ** 2
        cost_tot += np.min(chunk_sq, axis=1).sum()
    return cost_tot


if __name__ == '__main__':
    # Benchmark driver: for every dataset and every k, run LS_Bandit `itera`
    # times from random initial centers and print mean / std of the final
    # cost and of the running time.

    # k_list = [10, 10, 20, 30, 50, 100]
    k_list = [3, 5, 10]
    itera = 10
    save_name = 'ours_round50_sift.csv'
    # Other dataset lists used with this driver:
    # name_p = ['iris', 'seeds', 'glass', 'Who_FL', 'HCV_L', 'TRR_FL', 'urbanGB_L_10']
    # name_p = ['rds', 'urbanGB_L', 'rng', 'syn_1E7_2_3', 'USC1990']
    name_p = ['rds']
    rounds = 400
    result_columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    result_rows = []

    # NOTE: `data` and `k` are deliberately plain module-level names (the
    # driver is not wrapped in a function) so that any code in this file that
    # looks them up as globals keeps working.
    for dataname in name_p:
        data_path = "./data/" + str(dataname)
        if dataname == 'SUSY':
            data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                              usecols=tuple(range(1, 19)),
                              delimiter=",")
        elif dataname == 'HIGGS':
            data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
        elif dataname == 'SIFT':
            # .bvecs layout: every record is a 4-byte int32 dimension followed
            # by that many uint8 components.
            x = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
            d = x[:4].view('int32')[0]
            data = x.reshape(-1, d + 4)[:, 4:]
            data = np.array(data, dtype=float)
        else:
            data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')

        # Drop the last column for every dataset except 'syn_1E7_2_3'.  The
        # test used to compare the whole list `name_p` with the string, which
        # is always unequal, so the column was dropped unconditionally.
        if dataname != 'syn_1E7_2_3':
            data = data[:, 0:data.shape[1] - 1]

        LS_FAST = LS(n_clusters=0, rounds=rounds, trans=20, batch=500, total_batch=100, minibatchround=10)

        for k in k_list:
            print("k:", k)
            LS_FAST.n_clusters_ = k
            LS_FAST.rounds_ = rounds
            cost1_list = []
            time1_list = []
            for _ in range(itera):
                t1 = time.time()
                centers_id = random.sample(range(0, data.shape[0]), k)
                centers_init = data[centers_id]
                centers_f = LS_FAST.LS_Bandit(data, centers_init, 'v2')
                t2 = time.time()
                time1_list.append(t2 - t1)
                cost1_list.append(CheckCost(data, centers_f, 10000000))

            cost1_list = np.array(cost1_list)
            time1_list = np.array(time1_list)
            cost1_mean = np.mean(cost1_list)
            time1_mean = np.mean(time1_list)
            cost1_std = np.std(cost1_list)
            time1_std = np.std(time1_list)

            result_rows.append([dataname, "Ours", k, cost1_mean, cost1_std, time1_mean, time1_std])
            print("Ours", cost1_mean, cost1_std, time1_mean, time1_std)

            # To benchmark the exact local search as well, repeat the loop
            # above with LS_FAST.Fast_LS(data, centers_init) and record the
            # row under the method name "LSDS++".

    # Build the table once instead of concatenating onto an empty frame.
    result = pand.DataFrame(result_rows, columns=result_columns)
    # result.to_csv(save_name, index=False)
    print('-' * 80)
