import numpy as np
import ot,random
import copy
from multiprocessing import Pool


def robustOT(args):
    """Solve a robust (outlier-tolerant) optimal-transport problem between two point clouds.

    Each side is augmented with one dummy point whose transport cost to
    everything is zero. Real weights are scaled by ``1 / (1 - zeta)`` of their
    own side, and each dummy carries ``zeta / (1 - zeta)`` of the *other*
    side. When both weight vectors sum to 1 the augmented marginals balance,
    and at most a ``zeta_a`` (resp. ``zeta_b``) fraction of side a's (resp. b's)
    scaled mass can be routed to a dummy instead of being matched.

    Parameters
    ----------
    args : sequence
        ``(loc_a, wei_a, loc_b, wei_b, zeta_a, zeta_b, emd_Itermax)`` packed
        into one argument so the function can be used with ``Pool.map``.
        ``loc_*`` are point arrays accepted by ``ot.dist``; ``wei_*`` are
        array-likes (lists accepted) of per-point weights; ``zeta_*`` must lie
        in [0, 1); ``emd_Itermax`` is forwarded to ``ot.emd`` as ``numItermax``.

    Returns
    -------
    flowmatrix : ndarray, shape (len(loc_a), len(loc_b))
        Real-to-real block of the optimal plan, rescaled to sum to 1.
    loss : float
        Total cost of the augmented plan (dummy moves are free, so only
        real-to-real transport contributes).

    Raises
    ------
    ValueError
        If ``zeta_a`` or ``zeta_b`` is outside [0, 1).
    """
    loc_a, wei_a, loc_b, wei_b, zeta_a, zeta_b, emd_Itermax = args
    if not (0 <= zeta_a < 1 and 0 <= zeta_b < 1):
        raise ValueError("zeta_a and zeta_b must lie in [0, 1)")
    # Convert weights like locations are converted, so plain lists also work.
    wei_a = np.asarray(wei_a)
    wei_b = np.asarray(wei_b)

    costmatrix = ot.dist(np.array(loc_a), np.array(loc_b))
    s1, s2 = costmatrix.shape
    # Last row / column are the dummy points; their cost entries stay 0.
    r_costmatrix = np.zeros((s1 + 1, s2 + 1))
    r_costmatrix[:-1, :-1] = costmatrix

    r_wei_a = np.zeros(s1 + 1)
    r_wei_b = np.zeros(s2 + 1)
    r_wei_a[:-1] = wei_a / (1 - zeta_a)
    r_wei_a[-1] = zeta_b / (1 - zeta_b)   # dummy of a absorbs discarded mass of b
    r_wei_b[:-1] = wei_b / (1 - zeta_b)
    r_wei_b[-1] = zeta_a / (1 - zeta_a)   # dummy of b absorbs discarded mass of a

    r_flowmatrix = ot.emd(r_wei_a, r_wei_b, r_costmatrix, numItermax=emd_Itermax)
    loss = np.sum(r_costmatrix * r_flowmatrix)
    flowmatrix = r_flowmatrix[:-1, :-1]
    flowmatrix = flowmatrix / np.sum(flowmatrix)
    return flowmatrix, loss





# n1 = 1000; n2 = 800; dim = 3; emd_Itermax=100000
# loc_a = np.random.rand(n1,dim)
# wei_a = np.ones(n1) / n1
# loc_b = np.random.rand(n2,dim)
# wei_b = np.ones(n2) / n2
# zeta_a = 0.1; zeta_b = 0.2

# args = [loc_a,wei_a,loc_b,wei_b,zeta_a,zeta_b,emd_Itermax]
# robustOT(args)







# Gonzalez's algorithm (farthest-first traversal): select k centre distributions.
def my_gonzalez_RWD(locs_list,weis_list,k,zeta_a,zeta_b,emd_Itermax=100000,tau=10,poolNum=32):
    """Select k centre distributions with a robust-OT variant of Gonzalez's algorithm.

    The first centre is a uniformly random dataset. Each later centre is the
    dataset whose robust-OT distance to its nearest already-chosen centre is
    largest. Every chosen centre keeps its support points
    (``locs_list[idx]``), but its weights are replaced by a "purified"
    re-weighting. That re-weighting is derived from the (at most) ``tau``
    datasets closest to it.

    Parameters
    ----------
    locs_list, weis_list : indexable sequences of length n
        Support points and weights of each dataset.
    k : int
        Number of centres to select.
    zeta_a : float in [0, 1)
        Outlier fraction of the dataset side in every ``robustOT`` call.
    zeta_b : float in [0, 1)
        Outlier fraction of the centre side, used only in the purification
        step (distance computations use 0 on the centre side).
    emd_Itermax : int
        Iteration cap forwarded to the exact OT solver.
    tau : int
        Purification neighbourhood size; must be smaller than n.
    poolNum : int
        Number of worker processes.

    Returns
    -------
    center_locs_list, center_weis_list : list
        Support points and purified weights of the k centres.

    Raises
    ------
    ValueError
        If there are not more than ``tau`` datasets.
    """
    n = len(locs_list)
    if n <= tau:
        raise ValueError("dataset size is too small!")
    center_locs_list = []
    center_weis_list = []
    # Row j: robust-OT distance from every dataset to centre j. The large initial
    # value keeps rows of not-yet-chosen centres from affecting the column minimum.
    distance_matrix = np.ones((k, n)) * 10000000

    # One worker pool for the whole run instead of a fresh pool per map call.
    with Pool(poolNum) as p:
        for i in range(k):
            if i == 0:
                tmp_ind = random.randint(0, n - 1)
            else:
                # Distances to the centre added in the previous iteration.
                tmp_arg_list = [[locs_a, weis_a, center_locs_list[i - 1], center_weis_list[i - 1], zeta_a, 0, emd_Itermax]
                                for locs_a, weis_a in zip(locs_list, weis_list)]
                distance_matrix[i - 1] = [res[1] for res in p.map(robustOT, tmp_arg_list)]
                # Next centre: dataset farthest from its nearest chosen centre.
                tmp_ind = int(np.argmax(np.min(distance_matrix, axis=0)))
            center_locs = locs_list[tmp_ind]
            tmp_center_weis = weis_list[tmp_ind]

            # Purification step: robust OT (outliers removed on both sides) from
            # every dataset to the tentative centre.
            tmp_arg_list = [[locs_a, weis_a, center_locs, tmp_center_weis, zeta_a, zeta_b, emd_Itermax]
                            for locs_a, weis_a in zip(locs_list, weis_list)]
            all_flowmatrix_loss_list = p.map(robustOT, tmp_arg_list)
            threshold = sorted(res[1] for res in all_flowmatrix_loss_list)[tau]
            # Candidate weights: column sums of each plan, i.e. the mass each
            # centre point receives from that dataset.
            candidates = [(np.sum(flow, axis=0), locs, weis)
                          for (flow, loss), locs, weis in zip(all_flowmatrix_loss_list, locs_list, weis_list)
                          if loss < threshold]
            if candidates:
                # Pick the candidate re-weighting whose worst-case distance to the
                # candidate datasets (cover radius) is smallest; first minimum wins.
                radius_args = [[c_locs, c_weis, center_locs, cand[0], zeta_a, 0, emd_Itermax]
                               for cand in candidates
                               for _, c_locs, c_weis in candidates]
                radius_losses = [res[1] for res in p.map(robustOT, radius_args)]
                n_cand = len(candidates)
                cover_radii = [max(radius_losses[j * n_cand:(j + 1) * n_cand]) for j in range(n_cand)]
                center_weis = candidates[int(np.argmin(cover_radii))][0]
            else:
                # Ties at the threshold (or tau == 0) leave no strict candidates;
                # keep the raw weights instead of storing None as a centre.
                center_weis = tmp_center_weis
            center_locs_list.append(center_locs)
            center_weis_list.append(center_weis)
    return center_locs_list, center_weis_list





# Randomized swap-based local search refining k centre distributions.
def my_k_center_RWD(locs_list,weis_list,center_locs_list,center_weis_list,zeta_a,zeta_b,emd_Itermax=100000,tau=10,LS_Itermax=20,poolNum=32):
    """Refine k centre distributions by randomized swap-based local search.

    The objective is the k-centre cost ``max_j min_i d(dataset_j, centre_i)``,
    where ``d`` is the ``robustOT`` loss with outlier fraction ``zeta_a`` on the
    dataset side and 0 on the centre side. Each of ``LS_Itermax`` iterations:

    1. samples a dataset with probability proportional to its distance to the
       nearest current centre;
    2. purifies that dataset's weights using the datasets whose both-sided
       robust loss (``zeta_a``, ``zeta_b``) is at most the ``tau``-th order
       statistic;
    3. replaces the first centre whose substitution by this candidate strictly
       lowers the cost.

    The search stops early when every dataset is at distance 0 from some centre.

    Note: ``center_locs_list`` and ``center_weis_list`` are updated in place
    (item assignment) and also returned, so they must support ``obj[i] = value``.

    Parameters
    ----------
    locs_list, weis_list : indexable sequences
        Support points and weights of each dataset.
    center_locs_list, center_weis_list : mutable sequences of length k
        Initial centres (support points and weights).
    zeta_a, zeta_b : float in [0, 1)
        Outlier fractions forwarded to ``robustOT``.
    emd_Itermax : int
        Iteration cap forwarded to the exact OT solver.
    tau : int
        Purification neighbourhood size; must be smaller than ``len(locs_list)``.
    LS_Itermax : int
        Number of local-search iterations.
    poolNum : int
        Number of worker processes.

    Returns
    -------
    center_locs_list, center_weis_list
        The (in-place updated) centre containers.

    Raises
    ------
    ValueError
        If there are not more than ``tau`` datasets.
    """
    if len(locs_list) <= tau:
        raise ValueError("dataset size is too small!")
    k = len(center_locs_list)
    if k == 0:
        # Nothing to refine (and np.min over an empty matrix would fail).
        return center_locs_list, center_weis_list
    id_list = np.arange(len(locs_list))

    # One worker pool for the whole search instead of a fresh pool per map call.
    with Pool(poolNum) as p:

        def distances_to(c_locs, c_weis):
            """Robust-OT distance (centre-side zeta = 0) from every dataset to one centre."""
            arg_list = [[loc_a, wei_a, c_locs, c_weis, zeta_a, 0, emd_Itermax]
                        for loc_a, wei_a in zip(locs_list, weis_list)]
            return [res[1] for res in p.map(robustOT, arg_list)]

        # Row i: distances from every dataset to centre i.
        distance_matrix = np.array([distances_to(loc_c, wei_c)
                                    for loc_c, wei_c in zip(center_locs_list, center_weis_list)])
        cost = np.max(np.min(distance_matrix, axis=0))

        for _ in range(LS_Itermax):
            dis_vec = np.min(distance_matrix, axis=0)
            total = np.sum(dis_vec)
            if total <= 0:
                # Every dataset coincides with a centre: cost is already 0 and the
                # sampling probabilities below would be 0/0 (NaN).
                break
            tmp_ind = np.random.choice(id_list, p=dis_vec / total)
            alter_center_locs = locs_list[tmp_ind]
            tmp_center_weis = weis_list[tmp_ind]

            # Purification step: robust OT with outliers removed on both sides.
            tmp_arg_list = [[locs_a, weis_a, alter_center_locs, tmp_center_weis, zeta_a, zeta_b, emd_Itermax]
                            for locs_a, weis_a in zip(locs_list, weis_list)]
            all_flowmatrix_loss_list = p.map(robustOT, tmp_arg_list)
            threshold = sorted(res[1] for res in all_flowmatrix_loss_list)[tau]
            # Candidate weights: column sums of each plan (mass received per centre point).
            candidates = [(np.sum(flow, axis=0), locs, weis)
                          for (flow, loss), locs, weis in zip(all_flowmatrix_loss_list, locs_list, weis_list)
                          if loss <= threshold]
            if candidates:
                # Choose the re-weighting with the smallest cover radius over the
                # candidate datasets; first minimum wins.
                radius_args = [[c_locs, c_weis, alter_center_locs, cand[0], zeta_a, 0, emd_Itermax]
                               for cand in candidates
                               for _, c_locs, c_weis in candidates]
                radius_losses = [res[1] for res in p.map(robustOT, radius_args)]
                n_cand = len(candidates)
                cover_radii = [max(radius_losses[j * n_cand:(j + 1) * n_cand]) for j in range(n_cand)]
                alter_center_weis = candidates[int(np.argmin(cover_radii))][0]
            else:
                alter_center_weis = tmp_center_weis

            # Swapping: accept the first replacement that strictly lowers the cost.
            alter_distance_vec = np.asarray(distances_to(alter_center_locs, alter_center_weis))
            for i in range(k):
                tmp_distance_matrix = distance_matrix.copy()
                tmp_distance_matrix[i] = alter_distance_vec
                tmp_cost = np.max(np.min(tmp_distance_matrix, axis=0))
                if tmp_cost < cost:
                    cost = tmp_cost
                    center_locs_list[i] = alter_center_locs
                    center_weis_list[i] = alter_center_weis
                    distance_matrix = tmp_distance_matrix
                    break
    return center_locs_list, center_weis_list








# Demo / smoke test.
if __name__ == "__main__":
    # Random benchmark: m datasets, each a uniform-weight cloud of n 2-D points
    # drawn from [0, 100)^2.
    m = 1000; n = 100
    locs_list = np.random.rand(m, n, 2) * 100
    weis_list = np.ones((m, n)) / n
    # Number of centres to select.
    k = 5
    zeta_a, zeta_b = 0.1, 0.1

    # 1) Local search starting from k distinct random datasets.
    center_index_list = np.random.choice(np.arange(m), k, replace=False)
    center_locs_list = locs_list[center_index_list]
    center_weis_list = weis_list[center_index_list]
    center_locs_list, center_weis_list = my_k_center_RWD(
        locs_list, weis_list, center_locs_list, center_weis_list,
        zeta_a, zeta_b, emd_Itermax=100000, tau=10)

    # 2) Gonzalez initialisation followed by local search.
    center_locs_list, center_weis_list = my_gonzalez_RWD(locs_list, weis_list, k, zeta_a, zeta_b)
    center_locs_list, center_weis_list = my_k_center_RWD(
        locs_list, weis_list, center_locs_list, center_weis_list,
        zeta_a, zeta_b, emd_Itermax=100000, tau=10)

    # Report the result.
    print("选择的质心：")
    print("center_locs_list = ", len(center_locs_list))













