import numpy as np
import random
import math
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise_distances_chunked
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree, BallTree
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import pandas as pd1
from sklearn.cluster import MiniBatchKMeans
import time
import pandas as pand

def pairwise_batch(X, centers, batch_size=10_000_000):
    """
    Compute squared minimum pairwise distances between X and centers in batches.

    Rows of ``X`` are processed ``batch_size`` at a time so that only a
    ``(batch_size, n_centers)`` distance matrix is materialised at once.

    Parameters:
    - X: np.ndarray of shape (n_samples, n_features)
    - centers: np.ndarray of shape (n_centers, n_features)
    - batch_size: int, the number of points to process per batch (must be >= 1)

    Returns:
    - squared_min_dists: np.ndarray of shape (n_samples,), squared distances to closest center

    Raises:
    - ValueError: if batch_size is not positive.
    """
    # A negative step would make the range below empty and hand back the
    # uninitialised buffer, so reject bad batch sizes up front.
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    n = X.shape[0]
    squared_min_dists = np.empty(n)

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        dists = pairwise_distances(X[start:end], centers)
        # Reduce to the nearest center first, then square the n values.
        squared_min_dists[start:end] = np.min(dists, axis=1) ** 2

    return squared_min_dists

def fast_local_search(data, init, ybxl, r, t, k, W):
    """Improve a set of center indices by sampled single-swap local search,
    then run a short mini-batch k-means from the resulting centers.

    Each round samples one candidate point with probability proportional to
    its current weighted squared distance to the nearest center, evaluates
    swapping it with each of the ``k`` current centers, and accepts the best
    swap only if it lowers the cost below ``(1 - 1 / (100 * k))`` times the
    current cost.

    Parameters:
    - data: 2-D array of points, one row per point.
    - init: integer index array of the current centers (rows of ``data``).
      It is UPDATED IN PLACE whenever a swap is accepted.
    - ybxl: unused; kept for interface compatibility.
    - r: number of local-search rounds.
    - t: unused; kept for interface compatibility.
    - k: number of centers (``init`` positions 0..k-1 are swap candidates).
    - W: per-point weights multiplied into the squared distances.

    Returns:
    - (cost_glob, centers): the weighted local-search cost reached before the
      k-means step, and ``MiniBatchKMeans.cluster_centers_`` after fitting.
    """
    print("-----------------------------------Start Original Local Search------------------------------------")
    INF = float("inf")
    n_points = data.shape[0]

    # Weighted squared distance of every point to its closest current center.
    # Only the nearest center is used, so a single neighbour is requested
    # (asking for two fails when there is just one center).
    nbrs = NearestNeighbors(n_neighbors=1).fit(data[init])
    nearest_dist, _ = nbrs.kneighbors(data)
    point_cost = (nearest_dist[:, 0] ** 2) * W

    # Current clustering cost and the matching sampling distribution.
    cost_now = point_cost.sum()
    prob_modified = point_cost / cost_now if cost_now > 0 else None

    # Centers after the most recent accepted swap (None if no swap happened).
    init_ff = None

    for _ in range(r):
        # With zero (or non-finite) cost there is no valid distribution to
        # sample from and nothing left to improve.
        if prob_modified is None:
            break

        # Sample one candidate point from the cost-proportional distribution.
        if n_points < 10000 or n_points > 20000:
            next_point = np.random.choice(n_points, size=1, p=prob_modified)[0]
        else:
            # Draw one uniform threshold per point on a 1/10000 grid and use
            # the round only if exactly one point's probability exceeds its
            # threshold; otherwise skip the round.
            thresholds = np.random.randint(low=0, high=10000, size=n_points) / 10000
            candidates = np.argwhere(prob_modified - thresholds > 0)[:, 0]
            if len(candidates) != 1:
                continue
            next_point = candidates[0]

        # Try swapping the sampled point in for each current center.
        best_cost = INF
        best_j = None
        best_point_cost = None
        for j in range(k):
            trial_ids = init.copy()
            trial_ids[j] = next_point
            # Apply W so the candidate cost is comparable to cost_now.
            trial_point_cost = pairwise_batch(data, data[trial_ids]) * W
            trial_cost = trial_point_cost.sum()
            if trial_cost < best_cost:
                best_cost = trial_cost
                best_j = j
                best_point_cost = trial_point_cost

        # Accept the best swap only if it is a sufficient improvement.
        if best_cost < (1 - 1 / (100 * k)) * cost_now:
            init[best_j] = next_point
            cost_now = best_cost
            init_ff = np.array(data[init])
            # Reuse the distances already computed for the winning swap
            # instead of running another full pass over the data.
            prob_modified = best_point_cost / best_cost if best_cost > 0 else None

    # Report the reached cost; a non-finite cost is reported as infinity.
    cost_glob = cost_now if cost_now < INF else INF

    # Seed k-means with the swapped centers when any swap was accepted;
    # otherwise fall back to random initialisation.
    km_init = init_ff if init_ff is not None else 'random'
    km = MiniBatchKMeans(n_clusters=k, init=km_init, n_init=1, max_iter=10, batch_size=1024,
                         compute_labels=False, reassignment_ratio=0)
    km.fit(data)
    return cost_glob, km.cluster_centers_


def Projection(X, centers):
    """Return, for each row of ``centers``, the row of ``X`` closest to it.

    A BallTree is built over ``X`` (leaf size 40, default metric) and queried
    for the single nearest neighbour of every center; the matching rows of
    ``X`` are returned in the same order as ``centers``.
    """
    ball_tree = BallTree(X, leaf_size=40)
    _, neighbour_idx = ball_tree.query(centers, k=1)
    # query returns one column per requested neighbour; keep the only one.
    nearest_rows = neighbour_idx[:, 0]
    return X[nearest_rows]


def Projection1(X, centers):
    """Return, for each row of ``centers``, the index of the closest row of ``X``.

    A BallTree is built over ``X`` (leaf size 40, default metric) and queried
    for the single nearest neighbour of every center.

    Returns:
    - 1-D integer array with one index into ``X`` per center.
    """
    TreeL1 = BallTree(X, leaf_size=40)
    _, indL1 = TreeL1.query(centers, k=1)
    # query returns shape (n_centers, 1); flatten to one index per center.
    return indL1[:, 0]


def CheckCost(X, centers, batch):
    """Return the k-means cost of ``centers`` on ``X``.

    The cost is the sum over all rows of ``X`` of the squared distance to the
    nearest row of ``centers``. Rows are processed ``batch`` at a time so only
    a ``(batch, n_centers)`` distance matrix exists at once.

    Parameters:
    - X: 2-D array of points, one row per point.
    - centers: 2-D array of centers, one row per center.
    - batch: positive int, number of rows of ``X`` handled per chunk.

    Returns:
    - Total cost (0.0 when ``X`` has no rows).

    Raises:
    - ValueError: if ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))

    n_rows = X.shape[0]
    cost_tot = 0.0
    for start in range(0, n_rows, batch):
        stop = min(start + batch, n_rows)
        nearest = np.min(pairwise_distances(X[start:stop], centers), axis=1)
        # Squaring after the min gives the same values as squaring the whole
        # matrix first, without allocating a second (batch, n_centers) array.
        cost_tot += (nearest ** 2).sum()
    return cost_tot


if __name__ == '__main__':
    # Benchmark driver: for every dataset and every k, run the local search
    # `itera` times from random initial centers and report the mean / std of
    # the final cost and of the wall-clock time.
    # k_list = [10, 20, 30, 50, 100]
    k_list = [3, 5, 10]
    itera = 10  # independent repetitions per (dataset, k)
    save_name = 'LSPLUS_ORG_sift.csv'
    # name_p = ['syn_1E7_2_3','USC1990','SUSY','HIGGS']

    name_p = ['rds']
    rounds = 400  # local-search rounds per run
    columns = ['dataset', 'method', 'k', 'cost_mean', 'cost_std', 'time_mean', 'time_std']
    rows = []
    for dataname in name_p:
        data_path = "./data/" + str(dataname)
        if dataname == 'SUSY':
            data = np.loadtxt('../../../homeb/huangjy/SUSY.csv',
                              usecols=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18),
                              delimiter=",")
        elif dataname == 'HIGGS':
            data = np.loadtxt('../../../homeb/huangjy/HIGGS.csv', delimiter=",")
        elif dataname == 'SIFT':
            # .bvecs layout as read here: each record is a 4-byte int32
            # dimension followed by that many uint8 components.
            x = np.memmap('../../../homeb/huangjy/learn.bvecs', dtype='uint8', mode='r')
            d = x[:4].view('int32')[0]
            data = x.reshape(-1, d + 4)[:, 4:]
            data = np.array(data, dtype=float)
        else:
            data = np.loadtxt(data_path, delimiter=',', encoding='utf-8-sig')

        # Drop the last column for every dataset except the synthetic one.
        # Compare the current dataset name: comparing the whole list to a
        # string is always unequal and would strip the column unconditionally.
        if dataname != 'syn_1E7_2_3':
            data = data[:, 0:data.shape[1] - 1]

        for k in k_list:
            print("k:", k)

            cost3_list = []
            time3_list = []
            for _run in range(itera):
                # perf_counter is monotonic, so the measured interval cannot
                # be distorted by system clock adjustments.
                t1 = time.perf_counter()
                centers_id = np.array(random.sample(range(0, data.shape[0]), k), dtype=int)
                W = np.ones(data.shape[0])
                _, centers_f2 = fast_local_search(data, centers_id, 1 / 100, rounds, 2, k, W)
                t2 = time.perf_counter()
                time3_list.append(t2 - t1)
                # Cost evaluation is deliberately kept outside the timed span.
                cost3_list.append(CheckCost(data, centers_f2, 10000000))

            cost3_mean = np.mean(cost3_list)
            time3_mean = np.mean(time3_list)
            cost3_std = np.std(cost3_list)
            time3_std = np.std(time3_list)

            rows.append([dataname, "LS++", k, cost3_mean, cost3_std, time3_mean, time3_std])
            print("LS++", cost3_mean, cost3_std, time3_mean, time3_std)

    # Build the summary table once instead of concatenating per iteration.
    result = pand.DataFrame(rows, columns=columns)
    #result.to_csv(save_name, index=False)
    print('-' * 80)
