import numpy as np
import ot,random
import copy



# Gonzalez's algorithm: choose k centers by farthest-first traversal
def my_gonzalez_WD(locs_list, weis_list, k, *, verbose=True):
    """Choose ``k`` centers among discrete measures with Gonzalez's algorithm.

    Farthest-first traversal: the first center is a uniformly random measure;
    each subsequent center is the measure whose distance to its nearest
    already-chosen center is largest.

    Distances are computed as ``ot.emd2(weis_a, weis_c, ot.dist(locs_a, locs_c))``.
    With POT's default ``ot.dist`` metric (squared Euclidean) this is the
    squared 2-Wasserstein distance. ``emd2`` solves balanced OT, so all weight
    vectors are expected to have the same total mass.

    Args:
        locs_list: sequence of support-point arrays; ``locs_list[j]`` is an
            ``(n_j, d)`` array (as built in this file's demo).
        weis_list: sequence of weight vectors aligned with ``locs_list``.
        k: number of centers to select. If ``k`` exceeds the number of
            distinct measures, later picks may repeat an existing center.
        verbose: when True (default, matching the previous behavior), print
            the current max-min distance after each greedy pick.

    Returns:
        ``(center_locs_list, center_weis_list)``: two lists of length ``k``
        holding the selected measures' supports and weights.

    Raises:
        ValueError: if ``k > 0`` but ``locs_list`` is empty.
    """
    n = len(locs_list)
    center_locs_list = []
    center_weis_list = []
    if k > 0 and n == 0:
        raise ValueError("cannot choose centers from an empty set of measures")

    # Distance from every measure to its nearest already-chosen center.
    # Starting at +inf (rather than a large finite sentinel) guarantees that
    # real distances always win the minimum, whatever their scale.
    min_distances = np.full(n, np.inf)

    for i in range(k):
        if i == 0:
            tmp_ind = random.randint(0, n - 1)
        else:
            # Only the most recently added center can lower the running minimum.
            last_locs = center_locs_list[-1]
            last_weis = center_weis_list[-1]
            new_distances = np.array([
                ot.emd2(weis_a, last_weis, ot.dist(locs_a, last_locs))
                for locs_a, weis_a in zip(locs_list, weis_list)
            ])
            np.minimum(min_distances, new_distances, out=min_distances)
            # The farthest measure from all current centers becomes the next center.
            tmp_ind = int(np.argmax(min_distances))
            if verbose:
                print("cost = =--------------------------", np.max(min_distances))
        center_locs_list.append(locs_list[tmp_ind])
        center_weis_list.append(weis_list[tmp_ind])
    return center_locs_list, center_weis_list








# Randomized single-swap local search that refines k centers under the Wasserstein distance
def my_k_center_WD(locs_list, weis_list, center_locs_list, center_weis_list,
                   LS_Itermax=20, *, verbose=True):
    """Refine a set of centers for the Wasserstein k-center objective.

    The objective is the maximum, over all measures, of the distance to the
    nearest center. Each of ``LS_Itermax`` iterations works as follows:

    * sample one candidate measure with probability proportional to its
      distance to the nearest current center;
    * try swapping it in for each center in turn;
    * accept the first swap that strictly lowers the objective.

    Distances are ``ot.emd2(wei_a, wei_c, ot.dist(loc_a, loc_c))``. With POT's
    default ``ot.dist`` metric (squared Euclidean) this is the squared
    2-Wasserstein distance.

    Args:
        locs_list: sequence of support-point arrays, one per measure.
        weis_list: sequence of weight vectors aligned with ``locs_list``.
        center_locs_list: supports of the initial centers; the number of
            centers ``k`` is its length.
        center_weis_list: weights of the initial centers.
        LS_Itermax: number of local-search iterations.
        verbose: when True (default, matching the previous behavior), print
            progress information.

    Returns:
        ``(center_locs_list, center_weis_list)``: shallow copies of the input
        containers (same container type) with accepted swaps applied. The
        caller's containers are not modified.

    Raises:
        ValueError: if there are no centers or no measures.
    """
    # Work on shallow copies so the caller's arrays/lists are not mutated.
    center_locs_list = copy.copy(center_locs_list)
    center_weis_list = copy.copy(center_weis_list)
    # k comes from the inputs, never from a module-level global.
    k = len(center_locs_list)
    if k == 0 or len(locs_list) == 0:
        raise ValueError("need at least one center and one measure")

    def _distances_to(c_locs, c_weis):
        # Wasserstein cost from every measure to the single measure (c_locs, c_weis).
        return np.array([
            ot.emd2(wei_a, c_weis, ot.dist(loc_a, c_locs))
            for loc_a, wei_a in zip(locs_list, weis_list)
        ])

    # Row c holds the distances from every measure to center c.
    distance_matrix = np.array([
        _distances_to(loc_c, wei_c)
        for loc_c, wei_c in zip(center_locs_list, center_weis_list)
    ])
    cost = np.max(np.min(distance_matrix, axis=0))
    id_list = np.arange(len(locs_list))

    for _ in range(LS_Itermax):
        if verbose:
            print("===========================================")
            print("cost = --------------------------------------------", cost)
        dis_vec = np.min(distance_matrix, axis=0)
        total = np.sum(dis_vec)
        if total <= 0:
            # Every measure coincides with a center: the cost is already 0 and
            # the sampling distribution below would be 0/0 (NaN).
            break
        tmp_ind = np.random.choice(id_list, p=dis_vec / total)
        alter_center_locs = locs_list[tmp_ind]
        alter_center_weis = weis_list[tmp_ind]

        # Try replacing each center with the candidate; keep the first improvement.
        alter_distance_vec = _distances_to(alter_center_locs, alter_center_weis)
        for i in range(k):
            if verbose:
                print("i = ", i)
            tmp_distance = distance_matrix.copy()
            tmp_distance[i] = alter_distance_vec
            tmp_cost = np.max(np.min(tmp_distance, axis=0))
            if verbose:
                print("tmp_cost = ", tmp_cost)
            if tmp_cost < cost:
                cost = tmp_cost
                center_locs_list[i] = alter_center_locs
                center_weis_list[i] = alter_center_weis
                distance_matrix = tmp_distance
                break
    return center_locs_list, center_weis_list








# Demo / test code
if __name__ == "__main__":
    # Demo: a synthetic collection of uniform discrete measures in the plane,
    # clustered twice -- once from a random initialization and once from
    # Gonzalez's initialization -- each refined by local search.
    num_measures, support_size = 1000, 100
    k = 5  # number of centers to choose

    # Measure j: `support_size` points drawn uniformly in [0, 100)^2 with
    # uniform weights 1 / support_size.
    locs_list = np.random.rand(num_measures, support_size, 2) * 100
    weis_list = np.full((num_measures, support_size), 1.0 / support_size)

    # Baseline: k distinct measures picked uniformly at random as initial centers.
    random_ids = np.random.choice(np.arange(num_measures), k, replace=False)
    my_k_center_WD(locs_list, weis_list, locs_list[random_ids], weis_list[random_ids])

    # Gonzalez farthest-first initialization, then the same refinement.
    gonzalez_locs, gonzalez_weis = my_gonzalez_WD(locs_list, weis_list, k)
    my_k_center_WD(locs_list, weis_list, gonzalez_locs, gonzalez_weis)

    # Report the result.
    print("选择的质心：")
    print("center_locs_list = ", len(gonzalez_locs))













