import numpy as np
import ot,random
import copy
from numpy import arange,array
from __tool.tool import select_elements_byIndexes
from __alg.RWD_clustering import robustOT

# def myEmd(arg):
#     locs_a,weis_a,locs_b,weis_b = arg
#     # print(np.sum(weis_a),np.sum(weis_b))
#     costmatrix = ot.dist(locs_a,locs_b)
#     loss = ot.emd2(weis_a,weis_b,costmatrix)
#     return loss
    


def partition_list(input_list, partition_sizes):
    """Cut ``input_list`` into consecutive slices whose lengths are ``partition_sizes``.

    Slicing semantics apply, so a size running past the end yields a shorter
    (possibly empty) slice rather than an error.
    """
    bounds = [0]
    for size in partition_sizes:
        bounds.append(bounds[-1] + size)
    return [input_list[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


def merge_lists(list_of_lists):
    """Flatten exactly one level of nesting into a new list, preserving order."""
    flat = []
    for chunk in list_of_lists:
        flat.extend(chunk)
    return flat




def select_from_partitions(partition_sizes):
    """Draw one uniformly random global index from each consecutive partition.

    Partition ``j`` covers the half-open range ``[offset_j, offset_j + size_j)``,
    where ``offset_j`` is the sum of the preceding sizes. A size of 0 makes
    ``random.randint`` raise ``ValueError``.
    """
    picks = []
    offset = 0
    for size in partition_sizes:
        # randint is inclusive on both ends, hence the -1 on the upper bound
        picks.append(random.randint(offset, offset + size - 1))
        offset += size
    return picks





def max_index_in_partitions(input_list, partition_sizes):
    """Return, for every consecutive partition, the global index of its largest value.

    Ties resolve to the earliest position in the partition. A partition that
    runs past the end of ``input_list`` is truncated; an empty partition raises
    ``ValueError`` (from ``max``).
    """
    values = list(input_list)
    winners = []
    offset = 0
    for size in partition_sizes:
        chunk = values[offset:offset + size]
        # max() only replaces its candidate on a strictly greater key, so the
        # first maximal position wins, matching list.index on the max value
        best_pos = max(range(len(chunk)), key=chunk.__getitem__)
        winners.append(offset + best_pos)
        offset += size
    return winners





def repeat_elements(elements, counts):
    """Expand ``elements`` so each item appears ``counts[i]`` times, in order.

    Pairs are formed with ``zip``, so the shorter of the two inputs bounds the
    result; non-positive counts contribute nothing.
    """
    expanded = []
    for item, times in zip(elements, counts):
        for _ in range(times):
            expanded.append(item)
    return expanded





# # Example usage
# elements = [1, 2, 3]
# counts = [5, 6, 7]
# result = repeat_elements(elements, counts)
# print(result)
# # Output: [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3]




# # Example usage
# input_list = [1, 3, 5, 2, 8, 6, 4, 7, 9, 0, 2, 11]
# partition_sizes = [3, 4, 5]
# max_indices = max_index_in_partitions(input_list, partition_sizes)
# print(max_indices)  # Output: [2, 4, 11]






# # Example usage
# partition_sizes = [3, 4, 5]
# selected_numbers = select_from_partitions(partition_sizes)
# print(selected_numbers)  # Output differs on every run, e.g. [1, 4, 10]



# # Example usage
# list_of_lists = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# merged_list = merge_lists(list_of_lists)
# print(merged_list)  # Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]








def _select_centers(locs_list, weis_list, partition_sizes, k_ddim, zeta_a, zeta_b):
    """Greedily pick up to ``k_ddim`` centres inside every partition (farthest-first).

    The first centre of each partition is drawn at random; each later one is the
    partition member with the largest minimum distance to the centres chosen so far.

    Returns:
        center_indexes: de-duplicated positions (into the flat ``locs_list``) of
            all chosen centres.
        distance_matrix: ``k_ddim`` rows; row ``kk`` holds, for every sample, the
            ``robustOT`` distance to the ``kk``-th centre of that sample's own partition.
    """
    distance_matrix = []
    center_indexes = []
    for kk in range(k_ddim):
        if kk == 0:
            round_centers = select_from_partitions(partition_sizes)
        else:
            # farthest-first: the least-covered sample becomes the next centre
            min_distance_vec = np.min(distance_matrix, axis=0)
            round_centers = max_index_in_partitions(min_distance_vec, partition_sizes)
        center_indexes = center_indexes + round_centers
        # every sample is compared only with the centre drawn from its own partition
        center_per_sample = repeat_elements(round_centers, partition_sizes)
        # NOTE(review): element [1] of robustOT's result is used as the distance;
        # presumably the transport cost - confirm against robustOT
        distance_matrix.append([
            robustOT([la, wa, locs_list[c_ind], weis_list[c_ind], zeta_a, zeta_b])[1]
            for la, wa, c_ind in zip(locs_list, weis_list, center_per_sample)
        ])
    return list(set(center_indexes)), distance_matrix


def _uncovered_groups(distance_matrix, partition_sizes, radius, k_ddim):
    """Group the samples whose distance to every centre still exceeds ``radius``.

    Returns a list of non-empty index arrays (positions into the flat sample list),
    one per (partition, nearest centre) pair, ordered by partition and then centre.
    """
    distance_matrix = np.asarray(distance_matrix)
    nearest = np.argmin(distance_matrix, axis=0)
    uncovered = np.min(distance_matrix, axis=0) > radius
    groups = []
    start = 0
    for size in partition_sizes:
        end = start + size
        for i in range(k_ddim):
            hits = np.where((nearest[start:end] == i) & uncovered[start:end])[0]
            if len(hits) > 0:
                groups.append(hits + start)
        start = end
    return groups


def coreset(locs_list, weis_list, label_list, radius=100, k_ddim=4, zeta_a=0.1, zeta_b=0, verbose=True):
    """Build a coreset of weighted samples by repeated farthest-first covering.

    Each round, up to ``k_ddim`` centres are chosen inside every current partition
    and appended to the coreset. A sample counts as covered when its ``robustOT``
    distance to the nearest centre of its own partition is ``<= radius``. Covered
    samples are dropped. Uncovered ones are regrouped by (partition, nearest
    centre), and those groups become the next round's partitions. Rounds repeat
    until every sample is covered.

    Args:
        locs_list: per-sample location data, each item passed to ``robustOT``
            (the ``__main__`` demo uses an ``(m, n, 2)`` array).
        weis_list: per-sample weights, aligned with ``locs_list``.
        label_list: one label per sample, aligned with ``locs_list``.
        radius: covering radius, in the units of ``robustOT``'s distance.
        k_ddim: centres selected per partition per round; must be >= 1.
        zeta_a, zeta_b: forwarded unchanged to ``robustOT``.
        verbose: print per-round progress (on by default, as before).

    Returns:
        ``(None, None, None, cs_locs_list, cs_weis_list, cs_label_list)``. The
        first three slots are always ``None``. The last three hold the selected
        centres' locations, weights and labels, in matching order.

    Raises:
        ValueError: if ``k_ddim < 1``.
        RuntimeError: if a round covers no sample, which would otherwise loop forever.
    """
    if k_ddim < 1:
        raise ValueError("k_ddim must be >= 1")

    cs_locs_list = []
    cs_weis_list = []
    cs_label_list = []
    label_arr = array(label_list)

    locs_list_list = [locs_list]
    weis_list_list = [weis_list]
    partition_sizes = [len(ll) for ll in locs_list_list]
    locs_list = merge_lists(locs_list_list)
    weis_list = merge_lists(weis_list_list)
    # Original input position of every current sample. locs_list is re-flattened
    # from filtered groups each round, so its indices stop matching label_list
    # after the first round; labels must be looked up through this map.
    orig_index = arange(len(locs_list))
    if len(locs_list) == 0:
        return None, None, None, cs_locs_list, cs_weis_list, cs_label_list

    while len(locs_list_list) > 0:
        center_indexes, distance_matrix = _select_centers(
            locs_list, weis_list, partition_sizes, k_ddim, zeta_a, zeta_b)
        cs_locs_list = cs_locs_list + select_elements_byIndexes(locs_list, center_indexes)
        cs_weis_list = cs_weis_list + select_elements_byIndexes(weis_list, center_indexes)
        cs_label_list = cs_label_list + list(label_arr[orig_index[center_indexes]])

        remaining_before = len(locs_list)
        next_locs, next_weis, next_orig = [], [], []
        for group in _uncovered_groups(distance_matrix, partition_sizes, radius, k_ddim):
            group_locs = list(select_elements_byIndexes(locs_list, group))
            if len(group_locs) == 0:
                continue
            # keep locations, weights and original positions aligned group-by-group
            next_locs.append(group_locs)
            next_weis.append(list(select_elements_byIndexes(weis_list, group)))
            next_orig.append(orig_index[group])

        locs_list_list = next_locs
        weis_list_list = next_weis
        locs_list = merge_lists(locs_list_list)
        weis_list = merge_lists(weis_list_list)
        orig_index = np.concatenate(next_orig) if next_orig else arange(0)
        partition_sizes = [len(ll) for ll in locs_list_list]

        if verbose:
            print("locs_list_list = ", [len(ll) for ll in locs_list_list])
            print("locs_list_list,cs_locs_list = ", sum([len(ll) for ll in locs_list_list]), len(cs_locs_list))
            print("")

        if locs_list_list and len(locs_list) >= remaining_before:
            raise RuntimeError(
                "coreset made no progress: no sample was covered this round "
                "(is radius smaller than robustOT's self-distance?)")
    return None, None, None, cs_locs_list, cs_weis_list, cs_label_list



if __name__ == "__main__":
    # Demo on random data: m samples, each a uniform cloud of n 2-D points in
    # [0, 100)^2 with uniform weights, and a random label in {0, ..., 5}.
    m = 1000; n = 100
    locs_list = np.random.rand(m, n, 2) * 100
    weis_list = np.ones((m, n)) / n
    label_list = np.random.randint(low=0, high=6, size=m)
    k_ddim = 4
    radius = 80
    _, _, _, cs_locs_list, cs_weis_list, cs_label_list = coreset(
        locs_list, weis_list, label_list, radius=radius, k_ddim=k_ddim)
    print(len(cs_locs_list), len(cs_weis_list), len(cs_label_list))








