import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand


def _assign_nearest_two(data, init, W):
    """Assign every point to its nearest center and keep the runner-up distance.

    Returns a tuple ``(dist_2, affected_list)``:

    * ``dist_2`` has one row per data point; column 0 is the weighted squared
      distance to the nearest center in ``data[init]`` and column 1 the
      weighted squared distance to the second nearest one.
    * ``affected_list[c]`` is an integer array (ascending) with the indices of
      the points whose nearest center is ``init[c]``.
    """
    nbrs = NearestNeighbors(n_neighbors=2).fit(data[init])
    dist, ind = nbrs.kneighbors(data)
    dist_2 = dist ** 2
    dist_2[:, 0] = dist_2[:, 0] * W
    dist_2[:, 1] = dist_2[:, 1] * W

    # Group point indices by nearest center. A stable sort keeps the indices
    # inside every group in ascending order.
    nearest = ind[:, 0]
    order = np.argsort(nearest, kind="stable")
    sizes = np.bincount(nearest, minlength=init.shape[0])
    affected_list = np.split(order, np.cumsum(sizes)[:-1])
    return dist_2, affected_list


def _sampling_probabilities(nearest_cost):
    """Return ``nearest_cost`` normalised to sum to one, or None if the total is zero."""
    total = nearest_cost.sum()
    if total == 0:
        # Every point already coincides with a center: nothing to sample from.
        return None
    return nearest_cost / total


def _sample_candidate(n_points, prob_modified):
    """Draw the index of a candidate center; None means "skip this round".

    Outside the 10000..20000 size range one index is drawn proportionally to
    ``prob_modified``. Inside that range every point is kept independently
    when its probability exceeds a random threshold, and the round is only
    used when exactly one point survives.
    """
    if n_points < 10000 or n_points > 20000:
        # An integer first argument samples from arange(n_points) without
        # converting an n-element Python list on every call.
        return np.random.choice(n_points, size=1, p=prob_modified)[0]

    p1 = np.random.randint(low=0, high=10000, size=n_points) / 10000
    id_range = np.argwhere(prob_modified - p1 > 0)[:, 0]
    if len(id_range) != 1:
        return None
    return id_range[0]


def fast_local_search(data, init, ybxl, r, t, k, W):
    """Improve a set of k centers by sampling-based single-swap local search.

    Strategy: spend extra memory on the nearest / second-nearest distance of
    every point so that the cost of each candidate swap can be evaluated
    without recomputing all distances.

    In each of the ``r`` rounds one data point is sampled (proportionally to
    its current weighted cost), the cost of swapping it in for each of the
    ``k`` current centers is evaluated, and the cheapest swap is performed if
    it lowers the cost below ``(1 - 1 / (100 * k))`` times the current cost.
    The resulting centers seed a short ``MiniBatchKMeans`` run (which is fitted
    on ``data`` without the weights ``W``).

    Parameters
    ----------
    data : 2-D array
        One row per point.
    init : 1-D integer array
        Row indices of the initial centers. NOTE: updated in place; after the
        call it holds the indices chosen by the local search. At least two
        centers are required because two neighbours are queried per point.
    ybxl, t
        Unused; kept for interface compatibility.
    r : int
        Number of local-search rounds.
    k : int
        Number of centers; the swap loop indexes centers ``0..k-1``, so it
        should equal ``init.shape[0]``.
    W : array or scalar
        Weights multiplied into the squared distances.

    Returns
    -------
    (cost_glob, centers)
        ``cost_glob`` is the weighted cost reached by the local search (before
        the MiniBatchKMeans refinement); ``centers`` are the refined
        ``cluster_centers_``.
    """
    print("-----------------------------------Start Original Local Search------------------------------------")
    INF = float("inf")
    n_points = data.shape[0]

    # Preprocessing: distance and assignment information for the start centers.
    dist_2, affected_list = _assign_nearest_two(data, init, W)
    cost_now = dist_2[:, 0].sum()
    prob_modified = _sampling_probabilities(dist_2[:, 0])

    # Squared norms, reused for every candidate via |x - c|^2 = |x|^2 + |c|^2 - 2 x.c
    data2_sum = np.sum(data ** 2, axis=1)

    for _ in range(r):
        if prob_modified is None:
            # Cost is exactly zero; no swap can improve it.
            break

        next_point = _sample_candidate(n_points, prob_modified)
        if next_point is None:
            continue

        centers_new = data[next_point]
        dist_tot_new = data2_sum + np.sum(centers_new ** 2, axis=0) - 2 * np.sum(data * centers_new, axis=1)
        dist_tot_new = dist_tot_new * W

        # Per-point cost if the candidate were added without removing anything.
        dist_tot_new_modified = np.minimum(dist_2[:, 0], dist_tot_new)
        dist_tot_new_modified_sum = dist_tot_new_modified.sum()

        # Try swapping the candidate in for each current center j. Only points
        # currently assigned to j change: they fall back to the cheaper of
        # their second-nearest center and the candidate.
        cost_min = INF
        swap_id = None
        for j in range(k):
            id_affected = affected_list[j]
            dist_affected_modified = np.minimum(dist_2[id_affected, 1], dist_tot_new[id_affected])
            cost_new = (dist_tot_new_modified_sum
                        - dist_tot_new_modified[id_affected].sum()
                        + dist_affected_modified.sum())
            if cost_new < cost_min:
                cost_min = cost_new
                swap_id = j

        # Accept the best swap only if it gives a sufficient relative improvement.
        if cost_min < (1 - 1 / (100 * k)) * cost_now:
            init[swap_id] = next_point
            cost_now = cost_min
            dist_2, affected_list = _assign_nearest_two(data, init, W)
            prob_modified = _sampling_probabilities(dist_2[:, 0])

    # Seed the refinement with the local-search centers; fall back to a random
    # start only when no finite cost was obtained.
    if cost_now < INF:
        cost_glob = cost_now
        seed = data[init]
    else:
        cost_glob = INF
        seed = 'random'

    km = MiniBatchKMeans(n_clusters=k, init=seed, n_init=1, max_iter=10, batch_size=1024,
                         compute_labels=False, reassignment_ratio=0)
    km.fit(data)
    centers = km.cluster_centers_
    print("--------------------Final Clustering Cost---------------------", cost_glob)
    return cost_glob, centers


def Projection(X, centers):
    """Snap every center onto its nearest data point.

    A BallTree is built over the rows of ``X`` and queried with ``centers``;
    the returned array contains, for each center, the row of ``X`` closest to
    it. ``X`` is indexed with an integer array, so it must support NumPy-style
    fancy indexing.
    """
    tree = BallTree(X, leaf_size=40)
    nearest_rows = tree.query(centers, k=1, return_distance=False)
    return X[nearest_rows[:, 0]]


def Projection1(X, centers):
    """Return, for every center, the index of its nearest row in ``X``.

    A BallTree is built over the rows of ``X`` and queried with ``centers``;
    the result is a 1-D integer array with one index per center.
    """
    TreeL1 = BallTree(X, leaf_size=40)
    _, indL1 = TreeL1.query(centers, k=1)
    # query returns one column per requested neighbour; k=1 leaves a single column.
    return indL1[:, 0]


def CheckCost(X, centers, batch):
    """Return the k-means cost of ``centers`` on ``X``.

    The cost is the sum over all rows of ``X`` of the squared distance to the
    closest center. Rows are processed ``batch`` at a time so that the full
    point-by-center distance matrix never has to be held in memory.

    Parameters
    ----------
    X : 2-D array
        One row per point.
    centers : 2-D array
        One row per center.
    batch : int
        Number of rows of ``X`` handled per block; must be at least 1.

    Raises
    ------
    ValueError
        If ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))

    cost_tot = 0
    for start in range(0, X.shape[0], batch):
        # Slicing past the end is clipped, so the last block may be shorter.
        block = pairwise_distances(X[start:start + batch], centers) ** 2
        cost_tot += np.min(block, axis=1).sum()
    return cost_tot


def _load_dataset(name):
    """Load the benchmark dataset called ``name`` as a 2-D float array.

    SUSY, HIGGS and SIFT are read from fixed locations; any other name is read
    as a comma-separated text file from ``./data/<name>``. Except for
    ``'syn_1E7_2_3'`` the last column is dropped before the data is returned.
    """
    if name == 'SUSY':
        data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                          usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                          delimiter=",")
    elif name == 'HIGGS':
        data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
    elif name == 'SIFT':
        # .bvecs layout: each record is a 4-byte int32 dimension followed by
        # that many uint8 components.
        x = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
        d = x[:4].view('int32')[0]
        data = x.reshape(-1, d + 4)[:, 4:]
        data = np.array(data, dtype=float)
    else:
        data_path = "./data/" + str(name)
        data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')

    # Compare the dataset *name* (the original compared the whole name list to
    # a string, which is always unequal, so the column was dropped every time).
    if name != 'syn_1E7_2_3':
        data = data[:, 0:data.shape[1] - 1]
    return data


if __name__ == '__main__':
    # Benchmark configuration.
    # k_list = [10, 20, 30, 50, 100]
    k_list = [3, 5, 10]
    itera = 10  # repetitions per (dataset, k) pair
    # save_name = 'LS_PLUS_Rround50_HIGGS_727.csv'
    # name_p = ['iris', 'seeds', 'glass', 'Who_FL', 'HCV_L', 'TRR_FL', 'urbanGB_L_10']
    name_p = ['rds']
    rounds = 400  # local-search rounds per run

    result_columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    result_rows = []

    for dataname in name_p:
        data = _load_dataset(dataname)

        for k in k_list:
            print("k:", k)

            # Only the LS++ method is benchmarked; each repetition starts from
            # k uniformly sampled data points with unit weights.
            cost3_list = []
            time3_list = []
            for _ in range(itera):
                t1 = time.time()
                centers_id = random.sample(range(0, data.shape[0]), k)
                W = np.ones(data.shape[0])
                centers_id = np.array(centers_id, dtype=int)
                _, centers_f2 = fast_local_search(data, centers_id, 1 / 100, rounds, 2, k, W)
                t2 = time.time()
                time3_list.append(t2 - t1)
                cost3_list.append(CheckCost(data, centers_f2, 10000000))

            cost3_mean = np.mean(cost3_list)
            time3_mean = np.mean(time3_list)
            cost3_std = np.std(cost3_list)
            time3_std = np.std(time3_list)

            result_rows.append([dataname, "LS++", k, cost3_mean, cost3_std, time3_mean, time3_std])
            print("LS++", cost3_mean, cost3_std, time3_mean, time3_std)

    # Build the summary table once instead of concatenating onto an empty frame.
    result = pand.DataFrame(result_rows, columns=result_columns)
    # result.to_csv(save_name, index=False)
    print('-' * 80)
