from math import nan
import sys
sys.path.append("/home/xschen/workspace/fd_learning")
import numpy as np
from primalAlgorithm import linear_model
from primalAlgorithm import gaussian_process_model
import numericalSimulation.get_valid_utility as get_valid_utility  # numericalSimulation
import random

# Print NumPy arrays in full: with an infinite threshold, long arrays are
# never summarised with "..." in the debug output below.
np.set_printoptions(threshold=np.inf)

def get_k_min(array, k):
    """Return the indices of the ``k`` smallest entries of a 1-D array.

    Args:
        array: 1-D array-like of comparable values.
        k: Number of indices wanted.

    Returns:
        A 1-D integer array with the positions of the ``k`` smallest values,
        in no particular order.  When ``k`` is at least the number of
        elements, every index is returned.
    """
    values = np.asarray(array)
    # np.argpartition requires kth < len(values); asking for all (or more
    # than all) elements would raise ValueError, so answer that case directly.
    if k >= values.size:
        return np.arange(values.size)
    return np.argpartition(values, k)[:k]

def initialization(k_channel, client_characteristic):
    """Build the initial training sets by sampling every client at least once.

    Clients are visited in batches of ``k_channel``; each batch takes the
    clients with the fewest selections so far, so after
    ``ceil(count / k_channel)`` batches every client has been sampled.

    Args:
        k_channel: Number of channels, i.e. clients sampled per batch.
        client_characteristic: 2-D array; row ``i`` is the feature vector of
            client ``i``.

    Returns:
        Tuple ``(train_x_linear, train_v, train_x_gaussian, train_u,
        accumulate_select)``:

        * ``train_x_linear``: features of every sampled client (linear model).
        * ``train_v``: validity observation for each row of ``train_x_linear``.
        * ``train_x_gaussian``: features of the sampled clients that were
          observed as valid (Gaussian-process model).
        * ``train_u``: utility observation for each row of ``train_x_gaussian``.
        * ``accumulate_select``: per-client count of how often it was sampled.
    """
    count = len(client_characteristic)
    len_characteristic = client_characteristic.shape[1]
    accumulate_select = np.zeros(count)

    # Number of batches needed to cover all clients: ceil(count / k_channel).
    full_batches, remainder = divmod(count, k_channel)
    loop_count = full_batches + (1 if remainder else 0)

    # Collect observations in lists and build the arrays once at the end
    # instead of re-allocating with np.append on every sample.
    linear_rows = []
    valid_labels = []
    gaussian_rows = []
    utilities = []

    for _ in range(loop_count):
        # Indices of the k_channel clients selected least often so far.
        for j in get_k_min(accumulate_select, k_channel):
            features = np.array([client_characteristic[j]])
            linear_rows.append(features)
            whether_valid = get_valid_utility.get_whether_valid_participant(j)
            valid_labels.append(np.array([whether_valid]))
            accumulate_select[j] = accumulate_select[j] + 1
            # Only clients observed as valid contribute a utility sample.
            if whether_valid == 1:
                gaussian_rows.append(features)
                utility = get_valid_utility.get_utility(j)
                # atleast_1d accepts both a plain scalar and a 1-element
                # array; a 0-d array cannot be concatenated along axis 0.
                utilities.append(np.atleast_1d(utility))

    # Leading empty float arrays keep the result shapes/dtypes well defined
    # even when nothing was collected.
    train_x_linear = np.concatenate([np.empty([0, len_characteristic])] + linear_rows, axis=0)
    train_v = np.concatenate([np.empty([0])] + valid_labels, axis=0)
    train_x_gaussian = np.concatenate([np.empty([0, len_characteristic])] + gaussian_rows, axis=0)
    train_u = np.concatenate([np.empty([0])] + utilities, axis=0)
    return train_x_linear, train_v, train_x_gaussian, train_u, accumulate_select

def intersection(a1, b1, a2, b2):
    """Intersect the closed intervals ``[a1, b1]`` and ``[a2, b2]``.

    Helper for ``get_r_intersection``.

    Returns:
        ``np.array([left, right])`` with the end points of the overlap, or
        ``None`` when the two intervals do not overlap.
    """
    # One interval lies entirely to the left of the other: no overlap.
    if b1 < a2 or b2 < a1:
        return None

    lower, upper = max(a1, a2), min(b1, b2)
    # Still empty, e.g. when one of the inputs is an inverted interval.
    return None if lower > upper else np.array([lower, upper])

def get_r_intersection(candidate, r_intersection, confidence_interval):
    """Tighten each candidate's running interval R_t(x) with the new estimate.

    For every client whose ``candidate`` flag is not 0, the stored interval is
    replaced by its intersection with the new confidence interval.  When the
    two do not overlap, the stored interval is reset to the new confidence
    interval.  Rows of non-candidates are left untouched.

    ``r_intersection`` is updated in place and also returned.
    """
    for idx in range(len(r_intersection)):
        if candidate[idx] == 0:
            continue  # eliminated clients keep whatever interval they had
        previous = r_intersection[idx]
        latest = confidence_interval[idx]
        overlap = intersection(previous[0], previous[1], latest[0], latest[1])
        r_intersection[idx] = latest if overlap is None else overlap
    return r_intersection
        
def every_round_client_selection(k_channel, candidate, count, r_intersection_valid_pro, r_intersection_utility, accumulate_select):
    """Choose the clients to sample in the current round.

    Selection rule (restricted to clients that are still candidates):

    * the client whose uncertainty box has the longest diagonal, i.e. the
      largest ``w_t(x)``, is always chosen;
    * the remaining ``k_channel - 1`` slots go to the clients with the highest
      utility upper bound, ties broken by the validity upper bound;
    * if the longest-diagonal client is already among those
      ``k_channel - 1``, the top ``k_channel`` clients by the same ranking
      are returned instead.

    Args:
        k_channel: Number of channels, i.e. clients to select (expected >= 1).
        candidate: Per-client flag; 1 marks a client that is still a candidate.
        count: Total number of clients.  Not needed by the computation; kept
            so existing callers keep working.
        r_intersection_valid_pro: Array with one ``[lower, upper]`` row per
            client: the running interval R_t(x) of the validity estimate.
        r_intersection_utility: Same layout, for the utility estimate.
        accumulate_select: Per-client selection counts.  Not used by the
            current selection rule; kept so existing callers keep working.

    Returns:
        1-D integer array of selected client indices.  If no more than
        ``k_channel`` candidates remain, all of them are returned.
    """
    candidate = np.asarray(candidate)
    valid_box = np.asarray(r_intersection_valid_pro)
    utility_box = np.asarray(r_intersection_utility)

    candidate_idx = np.flatnonzero(candidate == 1)
    # Nothing to choose between: every remaining candidate is selected.
    if candidate_idx.size <= k_channel:
        return candidate_idx

    # Squared diagonal of each candidate's box; squaring keeps the ordering,
    # so the square root is not needed for a comparison.
    diagonal_sq = (
        (valid_box[candidate_idx, 1] - valid_box[candidate_idx, 0]) ** 2
        + (utility_box[candidate_idx, 1] - utility_box[candidate_idx, 0]) ** 2
    )
    # np.argmax returns the first maximum, so ties go to the lowest index.
    widest = candidate_idx[np.argmax(diagonal_sq)]

    # Candidates in ascending order of utility upper bound (primary key) and
    # validity upper bound (secondary key); lexsort takes the primary key last.
    order = np.lexsort((valid_box[candidate_idx, 1], utility_box[candidate_idx, 1]))
    ranking = candidate_idx[order]

    # Slice from an explicit start index: the shorthand ranking[-(k - 1):]
    # degenerates to the whole array when k_channel == 1.
    top_rest = ranking[ranking.size - (k_channel - 1):]
    if np.any(top_rest == widest):
        # The widest client is already covered; fill all slots by ranking.
        return ranking[ranking.size - k_channel:]
    return np.append(top_rest, [widest])


def client_selection(k_channel, client_characteristic, t_round):
    """Run the round-by-round client selection and elimination loop.

    Each round:

    1. fit the linear (validity) and Gaussian-process (utility) models on the
       samples gathered so far and obtain a confidence interval per client;
    2. intersect those with the running intervals R_t(x);
    3. eliminate candidates whose validity upper bound is below the best
       validity lower bound (while more than 60% of the clients remain);
    4. eliminate candidates that are dominated in both validity and utility
       (capped per round, never below ``k_channel`` candidates);
    5. select ``k_channel`` clients, observe them and extend the training sets.

    Once exactly ``k_channel`` candidates remain, the models are no longer
    updated and every later round just reports those candidates.

    Args:
        k_channel: Number of channels, i.e. clients selected per round.
        client_characteristic: 2-D array; row ``i`` is the feature vector of
            client ``i``.
        t_round: Number of rounds to run.

    Returns:
        ``(participate_round, select_round)``: for each client, the 1-based
        round in which it was eliminated (0 if it never was), and the number
        of rounds in which it was selected.
    """
    count = len(client_characteristic)
    candidate = np.ones(count)           # 1 = still a candidate, 0 = eliminated
    participate_round = np.zeros(count)  # 1-based elimination round; 0 = never eliminated
    select_round = np.zeros(count)       # number of rounds in which the client was selected
    train_x_linear, train_v, train_x_gaussian, train_u, accumulate_select = initialization(k_channel, client_characteristic)

    for t in range(t_round):
        if sum(candidate) == k_channel:
            # Only k candidates are left, so the selection is fixed and no
            # further estimation is needed.  Recompute k_sort from the flags:
            # this equals the previous round's selection and also covers the
            # case where this branch is hit in the very first round.
            # NOTE(review): select_round is not advanced in these rounds —
            # confirm that is intended.
            k_sort = np.where(candidate == 1)[0]
            print("k_sort: %d" %(t))
            print(k_sort)
            print("candidate:")
            print(candidate)
            print("participate_round:")
            print(participate_round)
            continue

        # Current confidence intervals for validity and utility.  The model
        # outputs are indexed below as one [lower, upper] row per client.
        confidence_interval_valid_pro = linear_model.linear_model(train_x_linear, train_v, candidate, client_characteristic)
        confidence_interval_utility = gaussian_process_model.gaussian_process_model(train_x_gaussian, train_u, candidate, client_characteristic)
        if t == 0:
            r_intersection_valid_pro = confidence_interval_valid_pro
            r_intersection_utility = confidence_interval_utility
        else:
            r_intersection_valid_pro = get_r_intersection(candidate, r_intersection_valid_pro, confidence_interval_valid_pro)
            r_intersection_utility = get_r_intersection(candidate, r_intersection_utility, confidence_interval_utility)

        # Zero the intervals of eliminated clients so that they have a
        # zero-length diagonal when the widest box is looked up later.
        eliminated = candidate != 1
        r_intersection_valid_pro[eliminated] = 0
        r_intersection_utility[eliminated] = 0

        # Validity pruning on R_t(x): drop candidates whose upper bound is
        # below the largest lower bound over all rows.
        valid_pro_lcb = np.max(r_intersection_valid_pro, axis=0)[0]
        for i in range(count):
            if candidate[i] == 0:
                continue
            if sum(candidate) <= count * 0.6:
                # Stop this kind of pruning once at most 60% of the clients remain.
                break
            if r_intersection_valid_pro[i][1] < valid_pro_lcb:
                candidate[i] = 0
                participate_round[i] = t + 1
                accumulate_select[i] = nan  # mark as eliminated in the counts

        # Pareto pruning on R_t(x): drop candidates whose upper bounds are
        # both below some other row's lower bounds.  The number of removals
        # per round is capped at a random value in [19, 22]; the original
        # author did this to make the resulting heat map look smoother.
        removed_this_round = 0
        removal_cap = random.randint(19, 22)
        for i in range(count):
            if candidate[i] == 0:
                continue
            if sum(candidate) <= k_channel:
                # Never prune below k candidates ("<=" also stops when the
                # validity pruning above already went below k).
                break
            if np.any((r_intersection_valid_pro[i][1] < r_intersection_valid_pro[:, 0]) & (r_intersection_utility[i][1] < r_intersection_utility[:, 0])):
                candidate[i] = 0
                participate_round[i] = t + 1
                accumulate_select[i] = nan
                removed_this_round = removed_this_round + 1
            if removed_this_round >= removal_cap:
                break

        # Select the clients for this round.
        k_sort = every_round_client_selection(k_channel, candidate, count, r_intersection_valid_pro, r_intersection_utility, accumulate_select)

        # Observe the selected clients and extend both training sets.
        for i in k_sort:
            select_round[i] = select_round[i] + 1
            # Count each selection exactly once.
            accumulate_select[i] = accumulate_select[i] + 1
            train_x_linear = np.append(train_x_linear, np.array([client_characteristic[i]]), axis = 0)
            whether_valid = get_valid_utility.get_whether_valid_participant(i)
            train_v = np.append(train_v, np.array([whether_valid]), axis = 0)
            # Unlike initialization(), the utility is recorded here whether or
            # not the client was observed as valid (the validity check was
            # disabled by the original author).
            train_x_gaussian = np.append(train_x_gaussian, np.array([client_characteristic[i]]), axis = 0)
            utility = get_valid_utility.get_utility(i)
            # atleast_1d accepts both a plain scalar and a 1-element array;
            # a 0-d array cannot be appended along axis 0.
            train_u = np.append(train_u, np.atleast_1d(utility), axis = 0)

        print("k_sort: %d" %(t))
        print(k_sort)
        print(f"participate_round: {participate_round}")
        print(f"select_round: {select_round}")
    return participate_round, select_round
        

if __name__ == '__main__':
    # Running this module as a script currently does nothing.  Everything
    # below is commented-out scratch code that was used to try out the NumPy
    # calls and the helpers in this file by hand; the last snippet shows how
    # client_selection() is meant to be invoked.

    # count = 10
    # k_channel = 3

    # train_x = np.array([[]])
    # train_v = np.array([])

    # print(train_x)
    # temp_one = np.array([1, 2])
    # print(temp_one)
    # train_x = np.append(train_x, temp_one)
    # print(train_x)
    # temp_two = np.array([3,4])
    # train_x = np.append(train_x, temp_two)
    # print(train_x)

    # accumulate_select = np.array([1, 5,4,3,2,1,1,1,nan])
    # k_sore = get_k_min(accumulate_select,5)
    # print(k_sore)
    # for i in k_sore:
    #     if i == 0:
    #         print("abc")
    #     print(accumulate_select[i])

    # train_v = np.append(train_v, np.array([2]))
    # print(train_v)

    # train_x_linear = np.array([[1, 2],[3, 4],[1,2],[7,8],[3, 4],[1,2]])
    # train_x_linear = np.empty([0,2])
    # print(train_x_linear)
    # train_v = np.array([1,0,1,1,0,0])
    # k_channel = 3 
    # client_characteristic = np.array([[1, 2],[3, 4],[7,8]])
    # train_x_linear = np.append(train_x_linear, np.array([client_characteristic[2]]), axis = 0)
    # print(train_x_linear)

    # train_v = np.empty([0])
    # whether_valid = get_valid_utility.get_whether_valid_participant(4)
    # train_v = np.append(train_v, np.array(whether_valid), axis = 0)
    # train_v = np.append(train_v, np.array(whether_valid), axis = 0)
    # print(train_v)

    # whether_valid = 1
    # train_v = np.append(train_v, np.array([whether_valid]), axis = 0)
    # print(train_v)
    
    # train_x_linear = np.array([[]])
    # print(client_characteristic[2])
    # print(np.array([client_characteristic[2]]))
    # print(np.append(train_x_linear, np.array([client_characteristic[2]]), axis = 0))
    # train_x_linear = np.append(train_x_linear, np.array([client_characteristic[2]]), axis = 0)

    # train_u = np.empty([0])
    # print(train_u)
    # utility = get_valid_utility.get_utility(1)  # a single numeric value is received here
    # train_u = np.append(train_u, np.array(utility), axis = 0)
    # train_u = np.append(train_u, np.array(utility), axis = 0)
    # print(train_u)

    # k_channel = 3 
    # client_characteristic = np.array([[1, 2],[3, 4],[7,8],[3, 2],[5, 4],[6,8],[5,5]])
    # client_selection(k_channel, client_characteristic)

    #print(np.ones(3))
    # r_intersection = np.array([[1.0,2], [2,3],[7,8]])
    # confidence_interval = np.array([[0.5,1.5], [2.5,3],[9,10]])
    # candidate = np.array([1,1,1])
    #r_intersection = get_r_intersection(candidate, r_intersection, confidence_interval)
    # print(r_intersection)

    # print(sum(np.array([1,0,1,1])))

    # print(np.max(r_intersection, axis=0)[1])

    # diagonal_length = [1,6,3,6]
    # argmax_temp = np.where(diagonal_length==np.max(diagonal_length))
    # argmax = argmax_temp[0]
    # print(len(argmax))
    # print(argmax)

    # Example invocation of the full simulation:
    #k_channel = 3
    #client_characteristic = np.array([[1, 2],[3, 4],[7,8],[3, 2],[5, 4],[6,8],[5,5]])
    #t_round = 1000
    #client_selection(k_channel, client_characteristic, t_round)
    pass
