import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand


class LS(object):
    """Local-search k-means clustering with a mini-batch k-means refinement.

    Attributes (set by the constructor):
        n_clusters_: number of centers; used in ``Fast_LS``'s swap-acceptance
            threshold (falls back to ``len(centers)`` when None or 0).
        rounds_: number of sampling/swap rounds performed by ``Fast_LS``.
        trans_, batch_, total_batch_, minibatchround_: stored configuration;
            not read by the methods of this class.
    """

    def __init__(self, n_clusters=None, rounds=400, trans=10, batch=100, total_batch=10, minibatchround=40):
        self.n_clusters_ = n_clusters
        self.rounds_ = rounds
        self.trans_ = trans
        self.batch_ = batch
        self.total_batch_ = total_batch
        self.minibatchround_ = minibatchround

    def CheckCost(self, X, centers):
        """Return the sum over rows of ``X`` of the squared distance to the nearest center."""
        tree = BallTree(centers, leaf_size=40)
        dist, _ = tree.query(X, k=1)
        return (dist[:, 0] ** 2).sum()

    def minibatch_kmeans(self, data, centers, batch, rounds, error):
        """Refine ``centers`` with adaptive-batch mini-batch Lloyd steps.

        Each round samples ``batch`` distinct rows of ``data``, assigns them
        to their nearest center, and moves every center that received points
        to the mean of those points. From the fourth round on, the batch
        grows by a factor ``1 + error`` while the last accepted step improved
        the batch cost by more than that factor. The grown batch is capped at
        ``max(1, 10 * k * log2(k))`` and at ``data.shape[0]``.

        Args:
            data: (n, d) array of points.
            centers: (k, d) array of starting centers (not modified).
            batch: initial batch size (clamped to ``data.shape[0]``).
            rounds: maximum number of rounds.
            error: relative tolerance; stop once a step improves the batch
                cost by less than a factor ``1 + error``.

        Returns:
            At convergence, the cheaper (on the last batch) of the previous
            and the freshly updated centers; otherwise the last accepted
            centers after ``rounds`` rounds.
        """
        n_points = data.shape[0]
        n_centers = centers.shape[0]
        # log2(1) == 0, so keep the cap positive: an empty batch cannot be assigned.
        batch_cap = max(1, 10 * n_centers * math.log2(n_centers))
        # random.sample cannot draw more points than exist.
        batch = min(int(math.ceil(batch)), n_points)
        ratio = 10
        for j in range(rounds):
            if j > 2 and ratio > 1 + error:
                grown = min(math.ceil(batch * (1 + error)), batch_cap)
                batch = min(math.ceil(grown), n_points)

            sample = data[random.sample(range(n_points), batch)]
            tree = BallTree(centers, leaf_size=40)
            dist_old, nearest = tree.query(sample, k=1)
            nearest = nearest[:, 0]
            cost_old = (dist_old[:, 0] ** 2).sum()

            # Lloyd step on the batch; centers that received no point stay put.
            sums = np.zeros((n_centers, data.shape[1]))
            np.add.at(sums, nearest, sample)
            counts = np.bincount(nearest, minlength=n_centers)
            centers_now = np.array(centers, dtype=float)
            hit = counts > 0
            centers_now[hit] = sums[hit] / counts[hit][:, None]

            sq_new = pairwise_distances(sample, centers_now, metric="euclidean") ** 2
            cost_new = sq_new.min(axis=1).sum()

            if cost_new == 0 or cost_old / cost_new < 1 + error:
                # Converged: hand back whichever center set is cheaper on this batch.
                return centers_now if cost_new < cost_old else centers

            if cost_new < cost_old:
                if j >= 2:
                    ratio = cost_old / cost_new
                centers = centers_now
        return centers

    @staticmethod
    def _nearest_two(dist):
        """Summarise a squared-distance matrix by each point's two nearest centers.

        Args:
            dist: (n, k) array; ``dist[i, c]`` is the squared distance from
                point i to center c.

        Returns:
            ``(dist_2, affected_list)``. ``dist_2`` is an (n, 2) float array
            holding each row's smallest and second-smallest entry; the second
            column is ``inf`` when k == 1. ``affected_list[c]`` is the
            ascending index array of points whose nearest center is c.
        """
        n_points, n_centers = dist.shape
        order = np.argsort(dist, axis=1)
        dist_2 = np.full((n_points, 2), np.inf)
        dist_2[:, 0] = np.take_along_axis(dist, order[:, :1], axis=1)[:, 0]
        if n_centers > 1:
            dist_2[:, 1] = np.take_along_axis(dist, order[:, 1:2], axis=1)[:, 0]
        nearest = order[:, 0]
        # A stable sort keeps point indices ascending inside each group.
        grouped = np.argsort(nearest, kind="stable")
        counts = np.bincount(nearest, minlength=n_centers)
        affected_list = np.split(grouped, np.cumsum(counts)[:-1])
        return dist_2, affected_list

    def Fast_LS(self, X, centers):
        """Improve ``centers`` by sampled single-swap local search, then refine.

        Each of ``self.rounds_`` rounds does the following:
          1. Draw one candidate point with probability proportional to its
             squared distance to its nearest current center.
          2. Evaluate swapping the candidate in for (a) the center closest to
             it and (b) one uniformly random center.
          3. Perform the cheaper swap if it lowers the cost below
             ``(1 - 1 / (100 * k)) * cost``.

        The result is then passed to scikit-learn ``MiniBatchKMeans``
        (``max_iter=10``, ``batch_size=1024``) as its initial centers.

        Args:
            X: (n, d) array of points.
            centers: (k, d) array of initial centers; left unmodified.

        Returns:
            The fitted ``MiniBatchKMeans.cluster_centers_``, shape (k, d).
        """
        centers = np.array(centers, dtype=float)  # private copy; caller's array untouched
        n_points, n_centers = X.shape[0], centers.shape[0]
        k = self.n_clusters_ if self.n_clusters_ else n_centers
        accept_factor = 1 - 1 / (100 * k)

        # dist[i, c]: squared distance from point i to center c.
        dist = pairwise_distances(X, centers, metric="euclidean") ** 2
        dist_2, affected_list = self._nearest_two(dist)
        cost_now = dist_2[:, 0].sum()

        for _ in range(self.rounds_):
            if cost_now <= 0:
                # Every point sits on a center: sampling is undefined and no swap helps.
                break
            prob = dist_2[:, 0] / cost_now
            candidate = np.random.choice(n_points, p=prob)
            new_center = X[candidate].reshape(1, -1)
            dist_new = pairwise_distances(X, new_center, metric="euclidean")[:, 0] ** 2
            # Per-point cost if the candidate is added and nothing is removed.
            with_new = np.minimum(dist_2[:, 0], dist_new)
            total_with_new = with_new.sum()

            closest = int(np.argmin(pairwise_distances(centers, new_center)[:, 0]))
            random_id = random.sample(range(n_centers), 1)[0]

            best_cost, swap_id = np.inf, None
            for j in (closest, random_id):
                members = affected_list[j]
                # Points served by center j fall back to min(second-nearest, candidate).
                fallback = np.minimum(dist_2[members, 1], dist_new[members])
                cost_new = total_with_new - with_new[members].sum() + fallback.sum()
                if cost_new < best_cost:
                    best_cost, swap_id = cost_new, j

            if best_cost < accept_factor * cost_now:
                centers[swap_id] = X[candidate]
                dist[:, swap_id] = dist_new  # reuse: identical to recomputing for the new center
                dist_2, affected_list = self._nearest_two(dist)
                cost_now = dist_2[:, 0].sum()

        # n_clusters must equal the number of rows of the explicit init array.
        km = MiniBatchKMeans(n_clusters=n_centers, init=centers, n_init=1, max_iter=10, batch_size=1024,
                             compute_labels=False, reassignment_ratio=0)
        km.fit(X)
        return km.cluster_centers_

def Projection(X, centers):
    """Snap each center onto its nearest data point.

    Returns the rows of ``X`` whose indices ``Projection1`` finds (one per
    center, in center order).
    """
    return X[Projection1(X, centers)]


def Projection1(X, centers):
    """Return, for each center, the row index of its nearest point in ``X``.

    Builds a BallTree (leaf_size=40) over ``X`` and queries the single
    nearest neighbour of every center; the result is a 1-D index array of
    length ``len(centers)``.
    """
    tree = BallTree(X, leaf_size=40)
    _, nearest = tree.query(centers, k=1)
    return nearest[:, 0]


def CheckCost(X, centers, batch):
    """Return the k-means cost of ``centers`` on ``X``, computed in chunks.

    The cost is the sum over rows of ``X`` of the squared Euclidean distance
    to the nearest center. Rows are processed ``batch`` at a time, so only a
    (batch, k) distance matrix is held in memory at once.

    Args:
        X: (n, d) array of points.
        centers: (k, d) array of centers.
        batch: rows per chunk; must be a positive integer.

    Returns:
        The total cost (0.0 when ``X`` has no rows).

    Raises:
        ValueError: if ``batch`` is not positive (the chunk loop would never advance).
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))
    cost_tot = 0.0
    for start in range(0, X.shape[0], batch):
        sq = pairwise_distances(X[start:start + batch], centers) ** 2
        cost_tot += sq.min(axis=1).sum()
    return cost_tot


if __name__ == '__main__':
    # Benchmark driver: for each dataset in ``name_p`` and each k in ``k_list``,
    # run ``LS.Fast_LS`` ``itera`` times from k distinct randomly chosen data
    # rows and report mean/std of the final cost and of the wall-clock time.
    k_list = [3, 5, 10]
    itera = 10
    # save_name = 'lsds_small.csv'
    # Other dataset groups used previously:
    # name_p = ['iris', 'seeds', 'glass', 'Who_FL', 'HCV_L', 'TRR_FL', 'urbanGB_L_10']
    # name_p = ['tail1', 'tail2', 'nontail1', 'nontail2']
    # name_p = ['syn_1E7_2_3', 'USC1990', 'SUSY', 'HIGGS']
    name_p = ['rds']
    rounds = 400
    columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    rows = []
    # NOTE: ``data`` and ``k`` stay module-level names on purpose; older
    # versions of ``LS.Fast_LS`` read them as globals.
    for dataname in name_p:
        if dataname == 'SUSY':
            data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                              usecols=tuple(range(1, 19)),
                              delimiter=",")
        elif dataname == 'HIGGS':
            data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
        elif dataname == 'SIFT':
            raw = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
            # Each record is d + 4 bytes: a 4-byte int32 holding d, then d uint8 values.
            d = raw[:4].view('int32')[0]
            data = np.array(raw.reshape(-1, d + 4)[:, 4:], dtype=float)
        elif dataname in ('tail1', 'tail2', 'nontail1', 'nontail2'):
            data = np.load(dataname + ".npy")
        else:
            data = np.loadtxt("./data/" + dataname, delimiter=',', encoding='utf-8-sig')

        # Drop the last column (presumably a label) for every dataset except
        # the synthetic one; compare the current name, not the whole list.
        if dataname != 'syn_1E7_2_3':
            data = data[:, :-1]

        ls_fast = LS(n_clusters=None, rounds=rounds, trans=20, batch=500, total_batch=100, minibatchround=100)

        for k in k_list:
            print("k:", k)
            ls_fast.n_clusters_ = k
            ls_fast.rounds_ = rounds

            costs, times = [], []
            for _ in range(itera):
                t1 = time.perf_counter()
                centers_init = data[random.sample(range(data.shape[0]), k)]
                centers_f2 = ls_fast.Fast_LS(data, centers_init)
                t2 = time.perf_counter()
                times.append(t2 - t1)
                costs.append(CheckCost(data, centers_f2, 10000000))

            cost3_mean, cost3_std = np.mean(costs), np.std(costs)
            time3_mean, time3_std = np.mean(times), np.std(times)
            rows.append([dataname, "LSDS++", k, cost3_mean, cost3_std, time3_mean, time3_std])
            print("LSDS++", cost3_mean, cost3_std, time3_mean, time3_std)

    result = pand.DataFrame(rows, columns=columns)
    # result.to_csv(save_name, index=False)
    print('-' * 80)
