import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand
from simsimd import cdist, DistancesTensor
from itertools import combinations


def generate_candidate_set(center_list):
    """Return every non-empty combination of ``center_list``, each as a list.

    Combinations are grouped by size (singletons first, the full set last);
    within one size they follow ``itertools.combinations`` order.
    """
    return [
        list(subset)
        for size in range(1, len(center_list) + 1)
        for subset in combinations(center_list, size)
    ]


def _fls_nearest_sq_dist(points, centers, weights):
    """Weighted squared Euclidean distance from each row of ``points`` to its nearest center."""
    return np.min(cdist(points, centers, metric="sqeuclidean"), axis=1) * weights


def _fls_sampling_probs(dist):
    """Normalize per-point costs to sampling probabilities (all zeros if the total cost is 0)."""
    total = dist.sum()
    if total > 0:
        return dist / total
    # Every point sits on a center: nothing can be sampled and no swap can help.
    return np.zeros_like(dist)


def _fls_candidate_swaps(points, single, multi, triple):
    """Enumerate ``[slots_out, points_in]`` swaps for every non-empty subset of ``points``.

    A 1-point subset is paired with each slot in ``single``, a 2-point subset
    with each slot pair in ``multi``, and larger subsets with ``triple``.
    """
    swaps = []
    for subset in generate_candidate_set(points):
        if len(subset) == 1:
            swaps.extend([[slot], subset] for slot in single)
        elif len(subset) == 2:
            swaps.extend([slots, subset] for slots in multi)
        else:
            swaps.extend([slots, subset] for slots in triple)
    return swaps


def _fls_apply_swap(init, slots_out, points_in):
    """Return a copy of ``init`` with ``init[slots_out[j]] = points_in[j]`` for each j."""
    swapped = init.copy()
    for slot, point in zip(slots_out, points_in):
        swapped[slot] = point
    return swapped


def fast_local_search(data, init, k, ybxl, r, t, c, alpha, epsilon, W, single, multi, triple, eta, rounds):
    """Sampling-based multi-swap local search for weighted k-means, then a short
    MiniBatchKMeans refinement.

    Centers are rows of ``data``, identified by their row indices. In each of
    ``rounds`` iterations, every point is kept independently when
    ``2 * p(x) > U``. Here ``p(x)`` is the point's share of the current
    weighted cost, and ``U`` is a uniform draw in ``[0, 1)``. Only rounds that
    keep exactly one or two points go on. Every swap of those points into one
    (``single``) or two (``multi``) center slots is then scored on one random
    subsample of ``data``. The best-scoring swap is accepted if it lowers the
    full-data cost.

    Parameters
    ----------
    data : ndarray, shape (n, d)
    init : array of int, shape (k,)
        Row indices of ``data`` used as the starting centers. Not modified.
    k : int
        Number of centers (also ``n_clusters`` for the refinement).
    epsilon : float
        Sets the subsample size ``ceil(3 * log2(k / epsilon) / epsilon)``,
        clipped to ``[1, n]``.
    W : ndarray, shape (n,)
        Per-point weights applied to squared distances.
    single, multi, triple : list
        Center slots (``single``), slot pairs (``multi``) and slot triples
        (``triple``) that may be swapped out. ``triple`` is only used for
        3-point candidate sets, which the 1-or-2-point filter never produces.
    rounds : int
        Number of sampling rounds.
    ybxl, r, t, c, alpha, eta
        Unused; kept for interface compatibility.

    Returns
    -------
    (cost, centers)
        ``cost`` is the weighted cost of the local-search medoid solution,
        measured before refinement. ``centers`` is the ``(k, d)`` array of
        MiniBatchKMeans cluster centers, refined starting from that solution.
    """
    n = data.shape[0]
    # Clip so random.sample never raises. The formula can exceed n, or be
    # non-positive when k / epsilon < 1.
    sample_size = math.ceil(3 * math.log2(k / epsilon) / epsilon)
    sample_size = max(1, min(n, sample_size))
    population = range(n)

    init = np.asarray(init).copy()
    dist = _fls_nearest_sq_dist(data, data[init], W)
    cost_now = dist.sum()
    prob = _fls_sampling_probs(dist)

    for _ in range(rounds):
        # Small inputs use a coarser 1e-4-resolution uniform draw (as originally tuned).
        if n < 20000:
            uniform = np.random.randint(low=0, high=10000, size=n) / 10000
        else:
            uniform = np.random.rand(n)
        picked = np.flatnonzero(prob * 2 - uniform > 0)
        # Only 1- or 2-point candidate sets are explored; otherwise skip the round.
        if not 1 <= len(picked) <= 2:
            continue
        swaps = _fls_candidate_swaps(list(picked), single, multi, triple)

        # Score every candidate swap on one shared random subsample.
        id_sample = random.sample(population, sample_size)
        data_s = data[id_sample]
        w_s = W[id_sample]
        best_sample_cost = float("inf")
        best_swap = None
        for slots_out, points_in in swaps:
            candidate = _fls_apply_swap(init, slots_out, points_in)
            sample_cost = _fls_nearest_sq_dist(data_s, data[candidate], w_s).sum()
            if sample_cost < best_sample_cost:
                best_sample_cost = sample_cost
                best_swap = candidate
        if best_swap is None:
            continue

        # Accept the subsample winner only if it improves the full-data cost.
        dist_next = _fls_nearest_sq_dist(data, data[best_swap], W)
        cost_next = dist_next.sum()
        if cost_next < cost_now:
            cost_now = cost_next
            prob = _fls_sampling_probs(dist_next)
            init = best_swap

    # Refine from the current solution. `init` holds the last accepted swap,
    # or the starting centers if none was accepted. The original used a -1
    # sentinel here, which sent a single row as the MiniBatchKMeans init
    # when no swap was ever accepted.
    km = MiniBatchKMeans(n_clusters=k, init=data[init], n_init=1, max_iter=10, batch_size=1024,
                         compute_labels=False, reassignment_ratio=0)
    km.fit(data)
    return cost_now, km.cluster_centers_


def Projection(X, centers):
    """Snap each row of ``centers`` to its nearest row of ``X``.

    Builds a BallTree (leaf_size=40) over ``X``, queries the single nearest
    neighbour of every center, and returns those rows of ``X``.
    """
    tree = BallTree(X, leaf_size=40)
    neighbour_idx = tree.query(centers, k=1)[1][:, 0]
    return X[neighbour_idx]


def Projection1(X, centers):
    """Return, for each row of ``centers``, the row index of its nearest neighbour in ``X``.

    Builds a BallTree (leaf_size=40) over ``X`` and queries one neighbour per center.
    The result is a 1-D integer array with one entry per center.
    """
    tree = BallTree(X, leaf_size=40)
    _, neighbour_idx = tree.query(centers, k=1)
    return neighbour_idx[:, 0]


def CheckCost(X, centers, batch):
    """Return the (unweighted) k-means cost of ``X`` against ``centers``.

    The cost is the sum, over rows of ``X``, of the squared Euclidean distance
    to the nearest row of ``centers``. Rows are processed ``batch`` at a time
    to bound the size of each distance matrix.

    Raises
    ------
    ValueError
        If ``batch`` is not positive. The original loop never advanced in
        that case and hung.
    """
    if batch <= 0:
        raise ValueError(f"batch must be a positive integer, got {batch!r}")
    n = X.shape[0]
    cost_tot = 0
    for start in range(0, n, batch):
        stop = min(start + batch, n)
        sq_dist = pairwise_distances(X[start:stop], centers) ** 2
        cost_tot += np.min(sq_dist, axis=1).sum()
    return cost_tot


if __name__ == '__main__':
    # Experiment driver. For each dataset and each k, run fast_local_search
    # `itera` times from random initial centers, then report the mean/std of
    # the final k-means cost (via CheckCost) and of the wall-clock time.
    k_list = [3, 5, 10]
    itera = 10
    # save_name = 'MLS_Rounds400_sift.csv'
    # Other dataset lists used previously:
    # name_p = ['iris', 'seeds', 'glass', 'Who_FL', 'HCV_L', 'TRR_FL', 'urbanGB_L_10']
    # name_p = ['syn_1E7_2_3', 'USC1990', 'SUSY', 'HIGGS']
    name_p = ['rds']
    rounds = 400
    columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    rows = []
    for dataname in name_p:
        data_path = "./data/" + str(dataname)
        if dataname == 'SUSY':
            data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                              usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                              delimiter=",")
        elif dataname == 'HIGGS':
            data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
        elif dataname == 'SIFT':
            x = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
            d = x[:4].view('int32')[0]
            data = x.reshape(-1, d + 4)[:, 4:]
            data = np.array(data, dtype=float)
        else:
            data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')

        # Drop the last column (presumably a label) for every dataset except the
        # synthetic one. Compare the current name: comparing the whole list
        # `name_p` to a string is always unequal.
        if dataname != 'syn_1E7_2_3':
            data = data[:, 0:data.shape[1] - 1]

        for k in k_list:
            print("k:", k)
            cost3_list = []
            time3_list = []
            single = list(range(k))
            multi = [[i, j] for i in range(k) for j in range(i + 1, k)]
            triple = [[i, j, j1] for i in range(k) for j in range(i + 1, k) for j1 in range(j + 1, k)]
            for _ in range(itera):
                t1 = time.perf_counter()
                centers_id = np.array(random.sample(range(0, data.shape[0]), k), dtype=int)
                W = np.ones(data.shape[0])
                _, centers_f2 = fast_local_search(data, centers_id.copy(), k, 1 / 100, 10, 2, 10, 2, 0.5, W.copy(),
                                                  single.copy(), multi.copy(), triple.copy(), 0.5, rounds)
                t2 = time.perf_counter()
                time3_list.append(t2 - t1)
                cost3_list.append(CheckCost(data, centers_f2, 10000000))

            cost3_mean = np.mean(cost3_list)
            cost3_std = np.std(cost3_list)
            time3_mean = np.mean(time3_list)
            time3_std = np.std(time3_list)
            rows.append([dataname, "MLS", k, cost3_mean, cost3_std, time3_mean, time3_std])
            print("MLS", cost3_mean, cost3_std, time3_mean, time3_std)

    # Build the result table once (avoids concatenating onto an empty frame).
    result = pand.DataFrame(rows, columns=columns)
    # result.to_csv(save_name, index=False)
    print('-' * 80)
