import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand
from banditpam import KMedoids





def _weighted_assignment(data, center_ids, W):
    """Assign every point to its nearest and second-nearest center.

    Parameters
    ----------
    data : ndarray of shape (n_points, n_features)
        The points to assign.
    center_ids : 1-D integer array
        Row indices into ``data`` of the current centers.
    W : 1-D array of length n_points
        Per-point weights multiplied into the squared distances.

    Returns
    -------
    dist_2 : ndarray of shape (n_points, 2)
        Weighted squared distance of each point to its nearest (column 0)
        and second-nearest (column 1) center, using the default metric of
        ``NearestNeighbors``.
    nearest : ndarray of shape (n_points,)
        Position within ``center_ids`` of each point's nearest center.
    members : list of ndarray
        ``members[c]`` holds, in ascending order, the indices of the points
        whose nearest center is ``c``.
    """
    nbrs = NearestNeighbors(n_neighbors=2).fit(data[center_ids])
    dist, ind = nbrs.kneighbors(data)
    dist_2 = dist ** 2
    dist_2[:, 0] = dist_2[:, 0] * W
    dist_2[:, 1] = dist_2[:, 1] * W
    nearest = ind[:, 0]
    # Group the point indices by nearest center without a Python-level loop
    # over the points; the stable sort keeps each group in ascending order.
    order = np.argsort(nearest, kind="stable")
    counts = np.bincount(nearest, minlength=len(center_ids))
    members = np.split(order, np.cumsum(counts)[:-1])
    return dist_2, nearest, members


def _sampling_distribution(costs):
    """Return ``costs`` normalised to sum to one, or None if the total is not positive."""
    total = costs.sum()
    if not total > 0:
        return None
    return costs / total


def fast_local_search(data, init, ybxl, r, t, k, W):
    """Refine a set of medoid centers by sampled single swaps, then Lloyd-style steps.

    The function runs three phases:

    1. Up to ``r`` rounds of local search. Each round samples one candidate
       point from a distribution proportional to the weighted squared
       distance to the nearest current center, evaluates swapping it with
       each of the ``k`` centers, and performs the cheapest swap when it
       lowers the cost below ``(1 - 1 / (100 * k))`` times the current cost.
    2. A Lloyd-style loop: each cluster's weighted mean is replaced by the
       data point nearest to it, and the new centers are kept while the
       weighted cost strictly decreases.
    3. A short ``MiniBatchKMeans`` fit (10 iterations) seeded with the
       centers found above. This fit is called without sample weights.

    Parameters
    ----------
    data : ndarray of shape (n_points, n_features)
        The points to cluster.
    init : 1-D integer array
        Row indices into ``data`` of the initial centers. It is copied, so
        the caller's array is left untouched.
    ybxl : unused
        Kept only for interface compatibility.
    r : int
        Number of local-search rounds.
    t : unused
        Kept only for interface compatibility.
    k : int
        Number of centers; expected to equal ``len(init)``.
    W : 1-D array of length n_points
        Per-point weights.

    Returns
    -------
    cost_glob : float
        Weighted cost of the medoid solution found by phases 1 and 2 (it is
        not recomputed after the MiniBatchKMeans refinement).
    centers : ndarray of shape (k, n_features)
        ``cluster_centers_`` of the final MiniBatchKMeans fit.
    """
    print("-----------------------------------Start Original Local Search------------------------------------")
    INF = float("inf")
    n_points = data.shape[0]

    # Work on a private copy so swaps do not modify the caller's index array.
    init = np.array(init, copy=True)

    # Preprocessing: nearest / second-nearest distances and cluster members.
    dist_2, nearest, affected_list = _weighted_assignment(data, init, W)

    # Current clustering cost and the matching sampling distribution.
    cost_now = dist_2[:, 0].sum()
    prob_modified = _sampling_distribution(dist_2[:, 0])
    sample_range = np.arange(n_points)

    # Squared norms of all points, reused for every candidate distance.
    data2_sum = np.sum(data ** 2, axis=1)

    # ---- Phase 1: fast local search ---------------------------------------
    for _ in range(r):
        if prob_modified is None:
            # Every point already has zero cost: no swap can improve, and the
            # sampling distribution would be undefined (0 / 0).
            break

        cost_min = INF
        swap_id = None

        # Sample one candidate point from the cost-proportional distribution.
        if n_points < 10000 or n_points > 20000:
            next_point = np.random.choice(sample_range, size=1, p=prob_modified)[0]
        else:
            # Mid-sized inputs use independent per-point draws and only
            # proceed when exactly one point is selected.
            p1 = np.random.randint(low=0, high=10000, size=n_points) / 10000
            id_range = np.argwhere(prob_modified - p1 > 0)[:, 0]
            if len(id_range) != 1:
                continue
            next_point = id_range[0]

        # Weighted squared distance of every point to the candidate center.
        centers_new = data[next_point]
        dist_tot_new = data2_sum + np.sum(centers_new ** 2, axis=0) - 2 * np.sum(data * centers_new, axis=1)
        dist_tot_new = dist_tot_new * W

        # Cost of every point if the candidate were simply added.
        nearest_cost = dist_2[:, 0]
        dist_tot_new_modified = np.where(dist_tot_new < nearest_cost, dist_tot_new, nearest_cost)
        dist_tot_new_modified_sum = dist_tot_new_modified.sum()

        # Enumerate the k possible swaps (candidate in, center j out).
        for j in range(k):
            # Only points whose nearest center is j change: they fall back to
            # the better of their second-nearest center and the candidate.
            id_affected = affected_list[j]
            second_cost = dist_2[id_affected, 1]
            candidate_cost = dist_tot_new[id_affected]
            dist_affected_modified = np.where(candidate_cost < second_cost, candidate_cost, second_cost)
            cost_new = dist_tot_new_modified_sum - (dist_tot_new_modified[id_affected]).sum() + dist_affected_modified.sum()

            if cost_new < cost_min:
                cost_min = cost_new
                swap_id = j

        # Accept the best swap only when it improves the cost sufficiently.
        if cost_min < (1 - 1 / (100 * k)) * cost_now:
            init[swap_id] = next_point
            cost_now = cost_min
            # Renew the distance structures for the new center set.
            dist_2, nearest, affected_list = _weighted_assignment(data, init, W)
            prob_modified = _sampling_distribution(dist_2[:, 0])

    # ---- Phase 2: Lloyd-style steps projected back onto data points --------
    nbrs = NearestNeighbors(n_neighbors=1).fit(data)
    indL = nearest
    init_ff = None
    while True:
        # Start from the current medoids (fancy indexing yields a copy).
        init_L_new = data[init]
        for i1 in range(k):
            id_i = np.flatnonzero(indL == i1)
            weight_sum = W[id_i].sum()
            if weight_sum == 0:
                # Empty (or zero-weight) cluster: keep the current medoid
                # instead of producing a NaN center from 0 / 0.
                continue
            init_L_new[i1] = np.sum(data[id_i] * W[id_i][:, None], axis=0) / weight_sum

        # Project every weighted mean onto its nearest data point.
        _, indL1 = nbrs.kneighbors(init_L_new)
        indL1 = indL1[:, 0]
        init_temp = data[indL1]

        # Recalculate the weighted cost for the projected centers.
        nbrs1 = NearestNeighbors(n_neighbors=1).fit(init_temp)
        distL, indL = nbrs1.kneighbors(data)
        indL = indL[:, 0]
        distL = (distL[:, 0] ** 2) * W
        cost_candidate = distL.sum()
        if cost_candidate < cost_now:
            cost_now = cost_candidate
            init = indL1.copy()
            init_ff = init_temp.copy()
        else:
            break

    # A non-finite or NaN cost is reported as infinity.
    cost_glob = cost_now if cost_now < INF else INF

    # ---- Phase 3: short MiniBatchKMeans refinement -------------------------
    # Seed with the Lloyd result when it improved; otherwise with the
    # local-search medoids, so the returned centers always derive from the
    # solution whose cost is reported.
    seed_centers = init_ff if init_ff is not None else data[init]
    km = MiniBatchKMeans(n_clusters=k, init=seed_centers, n_init=1, max_iter=10, batch_size=1024,
                         compute_labels=False, reassignment_ratio=0)
    km.fit(data)
    centers = km.cluster_centers_
    print("--------------------Final Clustering Cost---------------------", cost_glob)
    return cost_glob, centers



def Projection(X, centers):
    """Snap each row of ``centers`` to its nearest row of ``X``.

    A ``BallTree`` is built over ``X`` and queried with ``k=1`` for every
    row of ``centers``; the matching rows of ``X`` are returned in the same
    order as ``centers``.
    """
    tree = BallTree(X, leaf_size=40)
    nearest_rows = tree.query(centers, k=1)[1]
    return X[nearest_rows[:, 0]]


def Projection1(X, centers):
    """Return, for each row of ``centers``, the index of its nearest row in ``X``.

    A ``BallTree`` is built over ``X`` and queried with ``k=1``; the result
    is a 1-D array of row indices into ``X``, one per row of ``centers``.
    """
    TreeL1 = BallTree(X, leaf_size=40)
    _, indL1 = TreeL1.query(centers, k=1)
    # Only the indices are returned, so the projected points themselves are
    # not materialised here.
    return indL1[:, 0]


def CheckCost(X, centers, batch):
    """Return the sum over all rows of ``X`` of the squared distance to the nearest center.

    The rows of ``X`` are processed in slices of at most ``batch`` rows so
    that only a ``(batch, len(centers))`` distance matrix is held at a time.

    Parameters
    ----------
    X : ndarray of shape (n_points, n_features)
    centers : ndarray of shape (n_centers, n_features)
    batch : int
        Maximum number of rows of ``X`` per distance computation; must be
        positive.

    Returns
    -------
    The accumulated cost (``0`` when ``X`` has no rows).

    Raises
    ------
    ValueError
        If ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))

    n_points = X.shape[0]
    cost_tot = 0
    for start in range(0, n_points, batch):
        stop = min(start + batch, n_points)
        sq_dist = pairwise_distances(X[start:stop], centers) ** 2
        cost_tot += np.min(sq_dist, axis=1).sum()
    return cost_tot


if __name__ == '__main__':
    # Benchmark driver: for every dataset and every k, run BanditPAM followed
    # by a short MiniBatchKMeans refinement `itera` times, and write the mean
    # and standard deviation of cost and wall-clock time to a CSV file.
    k_list = [3, 5, 10]
    itera = 5
    save_name = 'small_vary_k_Bandit.csv'
    name_p = ['iris', 'seeds', 'glass', 'Who_FL', 'HCV_L', 'TRR_FL', 'urbanGB_L_10']
    # Larger alternative dataset list:
    # name_p = ['rds', 'urbanGB_L', 'rng', 'syn_1E7_2_3', 'USC1990']

    columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    rows = []

    for dataname in name_p:
        if dataname == 'SUSY':
            data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                              usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                              delimiter=",")
        elif dataname == 'HIGGS':
            data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
        elif dataname == 'SIFT':
            # .bvecs layout as read here: each record is a 4-byte int32
            # dimension followed by that many uint8 components.
            x = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
            d = x[:4].view('int32')[0]
            data = x.reshape(-1, d + 4)[:, 4:]
            data = np.array(data, dtype=float)
        else:
            data = np.loadtxt("../data/" + str(dataname), delimiter=',', encoding='utf-8-sig')

        # Drop the last column (presumably a label column - confirm per
        # dataset) for every dataset except 'syn_1E7_2_3'. The check must use
        # the current dataset name, not the whole list of names.
        if dataname != 'syn_1E7_2_3':
            data = data[:, :-1]

        for k in k_list:
            print("k:", k)

            cost3_list = []
            time3_list = []
            for _ in range(itera):
                t1 = time.time()
                kmed = KMedoids(n_medoids=k, algorithm="BanditPAM")
                kmed.fit(data, 'L2')
                centers_temp = np.array(kmed.medoids, dtype=int)
                km = MiniBatchKMeans(n_clusters=k, init=data[centers_temp], n_init=1, max_iter=10,
                                     batch_size=1024, compute_labels=False, reassignment_ratio=0)
                km.fit(data)
                centers_f2 = km.cluster_centers_
                t2 = time.time()
                time3_list.append(t2 - t1)
                # Cost evaluation is deliberately kept outside the timed span.
                cost3_list.append(CheckCost(data, centers_f2, 10000000))

            cost3_mean = np.mean(cost3_list)
            cost3_std = np.std(cost3_list)
            time3_mean = np.mean(time3_list)
            time3_std = np.std(time3_list)

            rows.append([dataname, "BanditPAM", k, cost3_mean, cost3_std, time3_mean, time3_std])
            print("BanditPAM", cost3_mean, cost3_std, time3_mean, time3_std)

    # Build the table once instead of concatenating onto an empty DataFrame
    # for every (dataset, k) pair.
    result = pand.DataFrame(rows, columns=columns)
    result.to_csv(save_name, index=False)
    print('-' * 80)
