import numpy as np
import ot,random
import copy
from ot.unbalanced import sinkhorn_stabilized_unbalanced


def UOT(arg):
    """Solve an entropic unbalanced OT problem between two weighted point sets.

    Parameters
    ----------
    arg : sequence
        ``(loc_a, wei_a, loc_b, wei_b, reg, reg_a, reg_b)``:
        ``loc_a`` / ``loc_b`` are point coordinates (anything ``np.array``
        accepts, one row per point), ``wei_a`` / ``wei_b`` the matching point
        weights, ``reg`` the entropic regularization passed to
        ``sinkhorn_stabilized_unbalanced``, and ``[reg_a, reg_b]`` its
        ``reg_m`` (marginal relaxation for the source / target side, per the
        POT documentation).

    Returns
    -------
    flowmatrix : np.ndarray
        The transport plan returned by the solver, rescaled to sum to 1.
    loss : float
        ``sum(costmatrix * flowmatrix)`` where ``costmatrix`` comes from
        ``ot.dist`` with its default metric.

    Raises
    ------
    ValueError
        If the solver returns a plan whose total mass is zero or not finite.
        Normalizing such a plan would silently produce NaNs everywhere.
    """
    loc_a, wei_a, loc_b, wei_b, reg, reg_a, reg_b = arg
    costmatrix = ot.dist(np.array(loc_a), np.array(loc_b))
    flowmatrix = sinkhorn_stabilized_unbalanced(
        wei_a, wei_b, costmatrix, reg=reg, reg_m=[reg_a, reg_b]
    )
    total_mass = np.sum(flowmatrix)
    # Dividing by 0 (or by NaN/inf from a diverged solve) would poison every
    # downstream loss comparison with NaN; fail loudly instead.
    if not np.isfinite(total_mass) or total_mass <= 0:
        raise ValueError(
            f"UOT produced a transport plan with invalid total mass {total_mass!r}"
        )
    # Loss is measured on the mass-normalized plan, so it is comparable across
    # problems whose unbalanced plans move different amounts of mass.
    flowmatrix = flowmatrix / total_mass
    loss = np.sum(costmatrix * flowmatrix)
    return flowmatrix, loss



# n1 = 1000; n2 = 800; dim = 3
# loc_a = np.random.rand(n1,dim)
# wei_a = np.ones(n1) / n1
# loc_b = np.random.rand(n2,dim)
# wei_b = np.ones(n2) / n2
# reg = 1; reg_a = 1; reg_b = 100000
# args = [loc_a,wei_a,loc_b,wei_b,reg,reg_a,reg_b]
# res = UOT(args)
# print(res[1])








# Gonzalez's algorithm (farthest-first traversal): select k centers.
def my_gonzalez_UOT(locs_list, weis_list, label_list, k, reg, reg_a, reg_b, MAXreg=10000, tau=10):
    """Select ``k`` centers from weighted point clouds by farthest-first traversal.

    The first center is a uniformly random item. Every later center is the
    item whose UOT distance to its nearest already-chosen center is largest.
    Each chosen item then goes through a "purification" step. That step keeps
    the item's point locations and label but replaces its point weights with
    the best candidate re-weighting (see the comments in the loop).

    Parameters
    ----------
    locs_list, weis_list, label_list : indexable sequences of equal length
        Point locations, point weights and label of each item.
    k : int
        Number of centers to select.
    reg, reg_a, reg_b :
        Regularization arguments forwarded to ``UOT``. ``reg_b`` is used only
        when transporting items onto a candidate center during purification.
    MAXreg :
        Used in place of ``reg_b`` whenever a distance *to a center* is
        measured. Presumably it is large so that the center side is close to
        balanced -- confirm.
    tau : int
        Purification considers the items whose loss to the candidate center
        is at most the ``tau``-th smallest loss (0-based). That is at least
        ``tau + 1`` items.

    Returns
    -------
    tuple of lists
        ``(center_locs_list, center_weis_list, center_label_list)``.

    Raises
    ------
    ValueError
        If ``len(locs_list) <= tau``, because the purification threshold
        would be out of range.
    """
    n = len(locs_list)
    if n <= tau:
        raise ValueError("dataset size is too small!")
    center_locs_list = []
    center_weis_list = []
    center_label_list = []
    # Distance from every item to its nearest chosen center so far. Start at
    # +inf rather than a large finite sentinel. With a finite sentinel, items
    # whose true distances exceed it would all tie, and argmax would keep
    # returning index 0.
    min_distances = np.full(n, np.inf)
    for i in range(k):
        if i == 0:
            tmp_ind = random.randint(0, n - 1)
        else:
            # Only the newest center can lower the running minimum.
            prev_locs = center_locs_list[i - 1]
            prev_weis = center_weis_list[i - 1]
            tmp_distance_vec = np.array([
                UOT([locs_a, weis_a, prev_locs, prev_weis, reg, reg_a, MAXreg])[1]
                for locs_a, weis_a in zip(locs_list, weis_list)
            ])
            min_distances = np.minimum(min_distances, tmp_distance_vec)
            # The item farthest from all chosen centers becomes the next one.
            tmp_ind = int(np.argmax(min_distances))
        center_locs = locs_list[tmp_ind]
        tmp_center_weis = weis_list[tmp_ind]
        center_label = label_list[tmp_ind]

        # Purification step: transport every item onto the chosen center and
        # keep the items with the smallest losses as candidates.
        all_flowmatrix_loss_list = [
            UOT([locs_a, weis_a, center_locs, tmp_center_weis, reg, reg_a, reg_b])
            for locs_a, weis_a in zip(locs_list, weis_list)
        ]
        threshold = sorted(loss for _, loss in all_flowmatrix_loss_list)[tau]
        # Use `<=` rather than `<`, matching my_k_center_UOT. That guarantees
        # at least tau + 1 candidates. A strict comparison could select none
        # when losses tie at the threshold, leaving the weights as None.
        candidates = [
            (np.sum(flow, axis=0), locs, weis)
            for (flow, loss), locs, weis in zip(all_flowmatrix_loss_list, locs_list, weis_list)
            if loss <= threshold
        ]
        # Each candidate weight vector is the column marginal of one item's
        # plan, i.e. one mass per center point. Keep the vector that minimizes
        # the worst-case distance from the candidate items to the re-weighted
        # center.
        cover_radius = np.inf
        center_weis = None
        for candidate_weis, _, _ in candidates:
            tmp_cover_radius = max(
                UOT([c_locs, c_weis, center_locs, candidate_weis, reg, reg_a, MAXreg])[1]
                for _, c_locs, c_weis in candidates
            )
            if tmp_cover_radius < cover_radius:
                cover_radius = tmp_cover_radius
                center_weis = candidate_weis
        if center_weis is None:
            # No usable candidate (e.g. NaN radii): keep the item's own weights
            # instead of storing None as a center.
            center_weis = tmp_center_weis
        center_locs_list.append(center_locs)
        center_weis_list.append(center_weis)
        center_label_list.append(center_label)
    return center_locs_list, center_weis_list, center_label_list





# Local-search (swap) refinement of an existing set of k centers.
def my_k_center_UOT(locs_list, weis_list, label_list, center_locs_list, center_weis_list, center_label_list, reg, reg_a, reg_b, tau=10, LS_Itermax=20, MAXreg=10000):
    """Improve a k-center solution under a UOT distance by randomized swaps.

    The cost is ``max_j min_c d(item_j, center_c)``, where ``d`` is the loss
    ``UOT([item_locs, item_weis, center_locs, center_weis, reg, reg_a, MAXreg])[1]``.

    Each of the ``LS_Itermax`` rounds does three things:
      1. Sample an item with probability proportional to its distance to its
         nearest center.
      2. Purify it: among the items whose loss onto it is at most the
         ``tau``-th smallest, pick the plan column-marginal that minimizes the
         worst-case distance to those items.
      3. Swap it in for the first center (in index order) whose replacement
         strictly lowers the cost.

    Parameters
    ----------
    locs_list, weis_list, label_list : indexable sequences of equal length
        Point locations, point weights and label of each item.
    center_locs_list, center_weis_list, center_label_list : mutable sequences
        The current centers. These are **modified in place** whenever a swap
        is accepted.
    reg, reg_a, reg_b :
        Regularization arguments forwarded to ``UOT``. ``reg_b`` is used only
        for the purification plans.
    tau : int
        Purification threshold index (0-based) into the sorted losses.
    LS_Itermax : int
        Number of local-search rounds.
    MAXreg :
        Replaces ``reg_b`` whenever a distance to a center is measured.

    Returns
    -------
    tuple
        The same three center sequences that were passed in, after update.

    Raises
    ------
    ValueError
        If ``len(locs_list) <= tau``.
    """
    n = len(locs_list)
    if n <= tau:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise ValueError("dataset size is too small!")
    k = len(center_locs_list)

    def _dist_to(c_locs, c_weis):
        # UOT loss from every item to one center (center side uses MAXreg).
        return [
            UOT([loc_a, wei_a, c_locs, c_weis, reg, reg_a, MAXreg])[1]
            for loc_a, wei_a in zip(locs_list, weis_list)
        ]

    # distance_matrix[c, j]: distance from item j to center c.
    distance_matrix = np.array(
        [_dist_to(loc_c, wei_c) for loc_c, wei_c in zip(center_locs_list, center_weis_list)]
    )
    cost = np.max(np.min(distance_matrix, axis=0))
    id_list = np.arange(n)
    for _ in range(LS_Itermax):
        dis_vec = np.min(distance_matrix, axis=0)
        total = np.sum(dis_vec)
        if total == 0:
            # Every nearest-center distance is 0, so the cost is 0. The
            # sampling distribution would be 0/0 (NaN), and no swap can
            # strictly improve the cost.
            break
        proc_vec = dis_vec / total
        tmp_ind = np.random.choice(id_list, p=proc_vec)
        alter_center_locs = locs_list[tmp_ind]
        tmp_center_weis = weis_list[tmp_ind]
        alter_center_labels = label_list[tmp_ind]

        # Purification step.
        all_flowmatrix_loss_list = [
            UOT([locs_a, weis_a, alter_center_locs, tmp_center_weis, reg, reg_a, reg_b])
            for locs_a, weis_a in zip(locs_list, weis_list)
        ]
        threshold = sorted(loss for _, loss in all_flowmatrix_loss_list)[tau]
        candidates = [
            (np.sum(flow, axis=0), locs, weis)
            for (flow, loss), locs, weis in zip(all_flowmatrix_loss_list, locs_list, weis_list)
            if loss <= threshold
        ]
        # Start from +inf. A finite sentinel such as 1e7 leaves the weights as
        # None when every radius is at least that large (large coordinates).
        cover_radius = np.inf
        alter_center_weis = None
        for candidate_weis, _, _ in candidates:
            tmp_cover_radius = max(
                UOT([c_locs, c_weis, alter_center_locs, candidate_weis, reg, reg_a, MAXreg])[1]
                for _, c_locs, c_weis in candidates
            )
            if tmp_cover_radius < cover_radius:
                cover_radius = tmp_cover_radius
                alter_center_weis = candidate_weis
        if alter_center_weis is None:
            # No candidate (or no usable radius): keep the item's own weights.
            alter_center_weis = tmp_center_weis

        # Swapping: accept the first replacement that strictly lowers the cost.
        alter_distance_vec = _dist_to(alter_center_locs, alter_center_weis)
        for i in range(k):
            tmp_distance_matrix = distance_matrix.copy()
            tmp_distance_matrix[i] = alter_distance_vec
            tmp_cost = np.max(np.min(tmp_distance_matrix, axis=0))
            if tmp_cost < cost:
                cost = tmp_cost
                center_locs_list[i] = alter_center_locs
                center_weis_list[i] = alter_center_weis
                center_label_list[i] = alter_center_labels
                distance_matrix = tmp_distance_matrix
                break
    return center_locs_list, center_weis_list, center_label_list








# Demo / smoke test.
if __name__ == "__main__":
    # Random dataset: m point clouds, each holding n uniformly weighted 2-D points.
    m = 1000; n = 100
    locs_list = np.random.rand(m, n, 2)
    weis_list = np.ones((m, n)) / n
    label_list = np.random.randint(low=0, high=6, size=m)
    # Number of centers to select.
    k = 5
    reg, reg_a, reg_b, MAXreg = 1, 1, 1, 10000

    # 1) Local search starting from k distinct random items.
    center_index_list = np.random.choice(np.arange(m), k, replace=False)
    center_locs_list = locs_list[center_index_list]
    center_weis_list = weis_list[center_index_list]
    center_label_list = label_list[center_index_list]
    center_locs_list, center_weis_list, center_label_list = my_k_center_UOT(
        locs_list, weis_list, label_list,
        center_locs_list, center_weis_list, center_label_list,
        reg, reg_a, reg_b, tau=10, MAXreg=MAXreg,
    )

    # 2) Gonzalez initialization followed by local search. Use the returned
    # centers rather than relying on in-place mutation of the inputs.
    center_locs_list, center_weis_list, center_label_list = my_gonzalez_UOT(
        locs_list, weis_list, label_list, k, reg, reg_a, reg_b, MAXreg=MAXreg,
    )
    center_locs_list, center_weis_list, center_label_list = my_k_center_UOT(
        locs_list, weis_list, label_list,
        center_locs_list, center_weis_list, center_label_list,
        reg, reg_a, reg_b, tau=10, MAXreg=MAXreg,
    )
    # Print the result.
    print("选择的质心：")
    print("center_locs_list = ",len(center_locs_list))













