import numpy as np
import ot,random
import copy
from numpy import array


# Gonzalez's algorithm: select k centers by farthest-first traversal
def my_gonzalez_WD(locs_list,weis_list,label_list,k):
    """Select ``k`` centers by farthest-first (Gonzalez) traversal.

    Every input item is a weighted point set: ``locs_list[j]`` holds its
    support points and ``weis_list[j]`` the matching weights. The cost
    between two items is ``ot.emd2`` evaluated on the ground-cost matrix
    returned by ``ot.dist`` (called with its default metric).

    The first center is drawn uniformly at random with ``random.randint``;
    each later center is the item whose smallest cost to the centers chosen
    so far is largest.

    Args:
        locs_list: Sequence of support-point collections, one per item.
            Each entry is passed through ``numpy.array``.
        weis_list: Sequence of weight vectors, aligned with ``locs_list``.
        label_list: Sequence of labels, aligned with ``locs_list``.
        k: Number of centers to select. ``k <= 0`` yields three empty lists.

    Returns:
        A tuple ``(center_locs_list, center_weis_list, center_label_list)``
        of three lists of length ``k`` holding the selected entries of the
        three input sequences (the same objects, not copies).

    Raises:
        ValueError: If ``k > 0`` and ``locs_list`` is empty.
    """
    center_locs_list = []
    center_weis_list = []
    center_label_list = []
    n = len(locs_list)
    if k > 0 and n == 0:
        raise ValueError("cannot select centers from an empty locs_list")

    # Smallest cost from every item to any center chosen so far. Seeding
    # with +inf (rather than a large finite constant) keeps the argmax
    # correct no matter how large the transport costs are.
    min_distances = np.full(n, np.inf)
    for i in range(k):
        if i == 0:
            # No center yet: start from a uniformly random item.
            tmp_ind = random.randint(0, n - 1)
        else:
            # Only the most recently added center can lower the minima, so
            # one pass against that center is enough.
            prev_locs = array(center_locs_list[-1])
            prev_weis = array(center_weis_list[-1])
            tmp_distance_vec = np.array([
                ot.emd2(array(weis_a), prev_weis, ot.dist(array(locs_a), prev_locs))
                for locs_a, weis_a in zip(locs_list, weis_list)
            ])
            min_distances = np.minimum(min_distances, tmp_distance_vec)
            # Next center: the item farthest from all current centers.
            tmp_ind = int(np.argmax(min_distances))
        center_locs_list.append(locs_list[tmp_ind])
        center_weis_list.append(weis_list[tmp_ind])
        center_label_list.append(label_list[tmp_ind])
    return center_locs_list,center_weis_list,center_label_list





# Local search for k-center: refine a given set of k centers by single swaps
def my_k_center_WD(locs_list,weis_list,label_list,center_locs_list,center_weis_list,center_label_list,LS_Itermax=20):
    """Refine a set of centers by randomized single-swap local search.

    The objective is the k-center cost: the largest, over all items, of the
    smallest cost between the item and any center. The cost between two
    items is ``ot.emd2`` evaluated on the ground-cost matrix returned by
    ``ot.dist`` (called with its default metric).

    Each iteration samples one candidate item with probability proportional
    to its current cost to the nearest center, then tries it in place of
    each center in turn and keeps the first swap that strictly lowers the
    objective.

    Args:
        locs_list: Sequence of support-point collections, one per item.
        weis_list: Sequence of weight vectors, aligned with ``locs_list``.
        label_list: Sequence of labels, aligned with ``locs_list``.
        center_locs_list: Support points of the initial centers.
        center_weis_list: Weights of the initial centers.
        center_label_list: Labels of the initial centers.
        LS_Itermax: Maximum number of local-search iterations.

    Returns:
        ``(center_locs_list, center_weis_list, center_label_list)``. These
        are the very objects passed in; accepted swaps are written into
        them in place.

    Raises:
        ValueError: If there are no initial centers or no input items.
    """
    k = len(center_locs_list)
    n = len(locs_list)
    if k == 0:
        raise ValueError("my_k_center_WD needs at least one initial center")
    if n == 0:
        raise ValueError("my_k_center_WD needs at least one input item")

    def costs_to_center(one_center_locs, one_center_weis):
        """Return the vector of transport costs from every item to one center."""
        one_center_locs = array(one_center_locs)
        one_center_weis = array(one_center_weis)
        return np.array([
            ot.emd2(array(wei_a), one_center_weis, ot.dist(array(loc_a), one_center_locs))
            for loc_a, wei_a in zip(locs_list, weis_list)
        ])

    # Row c holds the cost from every item to center c.
    distance_matrix = np.array([
        costs_to_center(loc_c, wei_c)
        for loc_c, wei_c in zip(center_locs_list, center_weis_list)
    ])
    cost = np.max(np.min(distance_matrix, axis=0))
    id_list = np.arange(n)
    for _ in range(LS_Itermax):
        dis_vec = np.min(distance_matrix, axis=0)
        total = np.sum(dis_vec)
        if total <= 0:
            # Every item already sits at cost 0 from some center: nothing
            # can improve, and normalizing would divide by zero.
            break
        # Sample the swap candidate proportionally to how badly it is
        # currently served.
        proc_vec = dis_vec / total
        tmp_ind = np.random.choice(id_list, p=proc_vec)
        alter_center_locs = locs_list[tmp_ind]
        alter_center_weis = weis_list[tmp_ind]
        alter_center_labels = label_list[tmp_ind]

        alter_distance_vec = costs_to_center(alter_center_locs, alter_center_weis)
        for i in range(k):
            # Objective if center i were replaced by the candidate.
            tmp_distance_matrix = distance_matrix.copy()
            tmp_distance_matrix[i] = alter_distance_vec
            tmp_cost = np.max(np.min(tmp_distance_matrix, axis=0))
            if tmp_cost < cost:
                # Accept the first improving swap and draw a new candidate.
                cost = tmp_cost
                center_locs_list[i] = alter_center_locs
                center_weis_list[i] = alter_center_weis
                center_label_list[i] = alter_center_labels
                distance_matrix = tmp_distance_matrix
                break
    return center_locs_list,center_weis_list,center_label_list








# Test code
if __name__ == "__main__":
    # Synthetic data set: m items, each made of n random 2-D points scaled
    # to [0, 100) and carrying uniform weights that sum to 1.
    m, n = 1000, 100
    locs_list = 100 * np.random.rand(m, n, 2)
    weis_list = np.ones((m, n)) / n
    label_list = np.random.randint(low=0, high=6, size=m)

    # Number of centers to select.
    k = 5

    # Run 1: local search started from k distinct, uniformly drawn items.
    center_index_list = np.random.choice(np.arange(m), k, replace=False)
    center_locs_list = locs_list[center_index_list]
    center_weis_list = weis_list[center_index_list]
    center_label_list = label_list[center_index_list]
    my_k_center_WD(
        locs_list, weis_list, label_list,
        center_locs_list, center_weis_list, center_label_list,
    )

    # Run 2: local search started from the Gonzalez (farthest-first) centers.
    center_locs_list, center_weis_list, center_label_list = my_gonzalez_WD(
        locs_list, weis_list, label_list, k
    )
    my_k_center_WD(
        locs_list, weis_list, label_list,
        center_locs_list, center_weis_list, center_label_list,
    )

    # Report how many centers were selected.
    print("选择的质心：")
    print("center_locs_list = ", len(center_locs_list))













