import logging
import math
from collections import OrderedDict
from random import Random
from scipy.spatial.distance import pdist, squareform
from .utils.auxiliary_to_suv import *
from math import nan

import numpy as np
import pandas as pd

from auxiliary_utils.client_characteristic import *


def create_training_selector_with_suv(args):
    """Build and return the SUV training selector configured by ``args``."""
    selector = _training_selector(args)
    return selector

class _training_selector(object):
    """SUV's training selector.

    Chooses ``k_channel`` clients per round. During the first
    ``initial_count`` rounds it picks the least-selected clients so that every
    client is tried. After that it builds confidence intervals on each
    client's validity (linear model) and utility (Gaussian-process model),
    intersects them with the previous rounds' intervals (R_t(x)), eliminates
    clients that are clearly worse, and selects the client with the widest
    combined interval plus the ``k_channel - 1`` candidates with the highest
    utility upper bound.

    Client index ``i`` in the per-client arrays (``candidate``,
    ``accumulate_select``, ``client_characteristic``, the interval arrays)
    refers to ``candidate_clients[i]``.
    """

    # Eliminations stop once the candidate pool has shrunk to this fraction of
    # ``client_num``. ``cur_candidate_num`` never grows back, so this caps the
    # total number of clients that can ever be eliminated.
    _MIN_CANDIDATE_RATIO = 0.99

    def __init__(self, args, sample_seed=233):
        # Seeds NumPy's global RNG (used for the random initial scores in register_client).
        np.random.seed(sample_seed)

        self.totalClients = OrderedDict()
        self.training_round = 0

        self.args = args

        self.utility_type = args.utility_type

        self.k_channel = args.k_channel  # number of clients selected every round
        self.client_num = args.client_num  # total number of clients
        self.max_client_data_size = args.max_client_data_size
        self.min_client_data_size = args.min_client_data_size
        self.len_characteristic = 6  # length of a client characteristic vector
        self.client_characteristic = []

        # Element i is 1 while client i is a candidate participant, 0 once eliminated.
        self.candidate = np.ones(self.client_num)
        self.cur_candidate_num = args.client_num
        # Number of times each client has been selected (NaN once eliminated).
        self.accumulate_select = np.zeros(self.client_num)

        # Linear (validity) model training data: row i is the characteristic of
        # the i-th selected client, element i of train_v its validity.
        self.train_x_linear = np.empty([0, self.len_characteristic])
        self.train_v = np.empty([0])

        # GP (utility) model training data: row i is the characteristic of the
        # i-th observed client, element i of train_u its utility.
        self.train_x_gaussian = np.empty([0, self.len_characteristic])
        self.train_u = np.empty([0])

        self.history_long = 15000  # max number of GP training samples kept

        self.initial_count = self.get_count_initialization()  # number of initialization rounds

        self.k_sort = np.empty([0])  # indices of the clients selected in the latest round

        self.r_intersection_valid_pro = np.empty([0])  # intersected validity intervals R_t(x)
        self.r_intersection_utility = np.empty([0])  # intersected utility intervals R_t(x)

        self.unexplored = set()  # client ids that have not reported feedback yet

        self.candidate_clients = []  # registered client ids, in registration order

        self.data_map_file = f'{args.data_map_dir}/train.csv'
        self.data_map_file_characteristic_normalization = args.data_map_file_characteristic_normalization
        self.data_map_file_characteristic = args.data_map_file_characteristic
        # Dict mapping client_id -> number of data entries for that client.
        self.client_id_counts = get_count_entries(self.data_map_file)

        self.accumulate_k_sort = []

    def register_client(self, client_id, feedbacks):
        """Register ``client_id`` the first time it is seen; later calls are no-ops.

        Appends the client to ``candidate_clients`` (fixing its index) together
        with its characteristic vector, and initialises its ``totalClients``
        entry:
        - ``reward`` (statistical utility) and ``duration`` (system utility), taken from ``feedbacks``;
        - ``time_stamp`` (round of last involvement);
        - ``count`` (times involved);
        - ``status`` (whether the client is available);
        - random initial ``validity`` / ``utility_loss`` / ``utility_efficiency`` scores.
        """
        if client_id not in self.totalClients:
            if client_id not in self.candidate_clients:
                self.candidate_clients.append(client_id)
                self.client_characteristic.append(self.get_characteristic_by_client_id(client_id))

            self.totalClients[client_id] = {}
            self.totalClients[client_id]['reward'] = feedbacks['reward']
            self.totalClients[client_id]['duration'] = feedbacks['duration']
            self.totalClients[client_id]['time_stamp'] = self.training_round
            self.totalClients[client_id]['count'] = 0
            self.totalClients[client_id]['status'] = True
            self.totalClients[client_id]['validity'] = np.random.uniform(0, 1)

            self.totalClients[client_id]['utility_loss'] = np.random.uniform(0, 1)
            self.totalClients[client_id]['utility_efficiency'] = np.random.uniform(0, 1)

            # Per-round histories used by the 'increase_*' utility types.
            self.totalClients[client_id]['history_efficiency'] = []
            self.totalClients[client_id]['history_train_acc'] = []

            self.unexplored.add(client_id)

    def get_round_simulate_validity(self, client_id):
        """Simulate whether ``client_id`` trains successfully this round (its validity this round)."""
        index = np.where(np.array(self.candidate_clients) == client_id)[0][0]
        return get_whether_valid_participant(self.client_characteristic[index], self.max_client_data_size, self.min_client_data_size)

    def get_validity(self, client_id=None):
        """Return the validity confidence intervals of all clients.

        ``client_id`` is unused (intervals are computed for every candidate);
        it is kept optional for backward compatibility with older callers.
        """
        return self.linear_model(self.train_x_linear, self.train_v, self.candidate, self.client_characteristic)

    def get_utility(self):
        """Return the utility confidence intervals of all clients."""
        return self.gaussian_process_model(self.train_x_gaussian, self.train_u, self.candidate, self.client_characteristic)

    def update_client_util(self, client_id, feedbacks):
        '''Store a client's round feedback and, if it was selected, add a GP training sample.

        @ feedbacks['reward']: statistical utility
        @ feedbacks['duration']: system utility
        @ feedbacks['time_stamp'], ['status'], ['utility_loss'],
          ['utility_efficiency'], ['train_acc']: stored / used below

        Raises ValueError if the client was selected in the latest round and
        ``self.utility_type`` is not a supported utility type.
        '''
        client = self.totalClients[client_id]
        client['reward'] = feedbacks['reward']
        client['duration'] = feedbacks['duration']
        client['time_stamp'] = feedbacks['time_stamp']
        client['count'] += 1
        client['status'] = feedbacks['status']
        client['utility_loss'] = feedbacks['utility_loss']
        # (x * 0.5 + 0.5) * 75 maps x in [-1, 1] onto [0, 75].
        client['utility_efficiency'] = (feedbacks['utility_efficiency'] * 0.5 + 0.5) * 75
        client['history_efficiency'].append(client['utility_efficiency'])
        client['history_train_acc'].append(feedbacks['train_acc'])

        # NOTE(review): history_efficiency above keeps the non-negated value even
        # for 'negative*' utility types — confirm this is intended.
        if 'negative' in self.utility_type:
            client['utility_efficiency'] = (-feedbacks['utility_efficiency'] * 0.5 + 0.5) * 75
        client['train_acc'] = (100 - feedbacks['train_acc']) * 0.01

        self.unexplored.discard(client_id)

        # Only clients chosen in the latest selection contribute GP training samples.
        client_index = self.candidate_clients.index(client_id)
        if client_index in self.k_sort:
            # Compute the utility before touching the training set so that an
            # unsupported utility type cannot leave train_x_gaussian and train_u
            # with different lengths.
            utility = self._compute_utility(client_id)
            logging.info(f"utility: {utility}")
            self.train_x_gaussian = np.append(self.train_x_gaussian, np.array([self.client_characteristic[client_index]]), axis=0)
            self.train_u = np.append(self.train_u, np.array([utility]), axis=0)

    def _compute_utility(self, client_id):
        """Return the observed utility of ``client_id`` according to ``self.utility_type``."""
        client = self.totalClients[client_id]
        utility_type = self.utility_type
        if utility_type == 'loss':
            return client['utility_loss']
        if utility_type == 'loss_multiply_efficiency':
            return client['utility_loss'] * client['utility_efficiency']
        if utility_type == 'efficiency':
            return client['utility_efficiency']
        if utility_type == 'train_acc':
            return client['train_acc']
        if utility_type == 'oort_reward':  # loss scaled by a system-utility factor
            prefer_duration = self._update_round_prefer_duration()
            utility = client['utility_loss'] * (prefer_duration / client['duration']) * self.alpha
            logging.info(f"client_id: {client_id}, utility: {utility}, duration: {client['duration']}, round_prefer_duration: {self.round_prefer_duration}")
            return utility
        if utility_type in ('efficiency_multiply_system_utility', 'negative_efficiency_multiply_system_utility'):
            prefer_duration = self._update_round_prefer_duration()
            utility = client['utility_efficiency'] * (prefer_duration / client['duration']) * self.alpha
            logging.info(f"client_id: {client_id}, utility: {utility}, duration: {client['duration']}, round_prefer_duration: {self.round_prefer_duration}")
            return utility

        history_efficiency = client['history_efficiency']
        history_train_acc = client['history_train_acc']
        if utility_type == 'increase_efficiency':
            if len(history_efficiency) < 2:
                return 1
            return history_efficiency[-1] - history_efficiency[-2]
        if utility_type == 'increase_train_acc':
            if len(history_train_acc) < 2:
                return 100
            return history_train_acc[-1] - history_train_acc[-2]
        if utility_type == 'increase_efficiency_multiply_train_acc':
            if len(history_efficiency) < 2 or len(history_train_acc) < 2:
                return 1
            return (history_efficiency[-1] - history_efficiency[-2]) * (history_train_acc[-1] - history_train_acc[-2])
        if utility_type == 'increase_train_acc_mulitply_loss':
            if len(history_train_acc) < 2:
                return 100 * client['utility_loss']
            return (history_train_acc[-1] - history_train_acc[-2]) * client['utility_loss']
        raise ValueError(f"Unsupported utility_type: {utility_type!r}")

    def _update_round_prefer_duration(self):
        """Refresh ``round_threshold``/``alpha`` from args and return the preferred round duration.

        The preferred duration is the element at position
        ``len * round_threshold / 100`` (clamped to the last element) of all
        registered clients' durations sorted ascending; it is also stored in
        ``self.round_prefer_duration``.
        """
        self.round_threshold = self.args.round_threshold
        self.alpha = self.args.round_penalty
        sortedDuration = sorted(info['duration'] for info in self.totalClients.values())
        logging.info(f"sortedDuration: {sortedDuration}")
        position = min(int(len(sortedDuration) * self.round_threshold / 100.), len(sortedDuration) - 1)
        self.round_prefer_duration = sortedDuration[position]
        return self.round_prefer_duration

    def select_participant(self, t_round):
        '''Select the clients that participate in round ``t_round``.

        @ t_round: index of the current training round; rounds up to
          ``self.initial_count`` are initialization rounds.

        Returns the selected client ids as a numpy array. Unless only
        ``k_channel`` candidates remain, the selection is also recorded via
        ``update_train_sample``.
        '''
        logging.info('k_sort of last round: {}'.format(self.k_sort))
        logging.info('candidate: {}'.format(self.candidate))
        logging.info('cur_candidate_num: {}'.format(self.cur_candidate_num))
        client_ids = np.array(self.candidate_clients)  # client index -> client id
        logging.info('select_t_round: {}'.format(t_round))

        if t_round <= self.initial_count:
            # Initialization: the k least-selected clients, so every client gets tried.
            self.k_sort = get_k_min(self.accumulate_select, self.k_channel)
            self.accumulate_k_sort.extend(self.k_sort)
            self.accumulate_k_sort = sorted(self.accumulate_k_sort)
            logging.info(f"accumulate_k_sort: {self.accumulate_k_sort}")
            self.update_train_sample()
            return np.take(client_ids, self.k_sort)

        if self.cur_candidate_num == self.k_channel:
            # Only k candidates are left: they are selected every round, so the
            # models need not be evaluated.
            self.k_sort = np.where(self.candidate == 1)[0]
            return np.take(client_ids, self.k_sort)

        import datetime
        logging.info(f"get_validity {datetime.datetime.now()}")
        confidence_interval_valid_pro = self.get_validity()
        logging.info(f"get_utility {datetime.datetime.now()}")
        confidence_interval_utility = self.get_utility()
        logging.info(f"finish {datetime.datetime.now()}")

        if t_round - 1 == self.initial_count:
            # First model-based round: there is no previous R_t(x) to intersect with.
            self.r_intersection_valid_pro = confidence_interval_valid_pro
            self.r_intersection_utility = confidence_interval_utility
        else:
            self.r_intersection_valid_pro = get_r_intersection(self.candidate, self.r_intersection_valid_pro, confidence_interval_valid_pro)
            self.r_intersection_utility = get_r_intersection(self.candidate, self.r_intersection_utility, confidence_interval_utility)

        # Reset eliminated clients' intervals to [0, 0].
        for i in range(self.client_num):
            if self.candidate[i] != 1:
                self.r_intersection_valid_pro[i] = np.array([0, 0])
                self.r_intersection_utility[i] = np.array([0, 0])

        # Candidates at the start of this round's eliminations. Only their
        # intervals may be used to eliminate others; the [0, 0] placeholders of
        # previously eliminated clients are not real estimates.
        active = self.candidate == 1
        self._eliminate_low_validity(active)
        self._eliminate_dominated(active)

        self.k_sort = self.every_round_client_selection()
        self.update_train_sample()
        return np.take(client_ids, self.k_sort)

    def _eliminate_low_validity(self, active):
        """Eliminate candidates whose validity UCB is below the best validity LCB among ``active``."""
        if not np.any(active):
            return
        best_validity_lcb = np.max(self.r_intersection_valid_pro[active, 0])
        for i in range(self.client_num):
            if self.candidate[i] == 0:
                continue
            if self.cur_candidate_num <= self.client_num * self._MIN_CANDIDATE_RATIO:
                break
            if self.r_intersection_valid_pro[i][1] < best_validity_lcb:
                self.candidate[i] = 0
                self.cur_candidate_num = self.cur_candidate_num - 1
                # NaN keeps the client out of the "least selected" ranking.
                self.accumulate_select[i] = nan

    def _eliminate_dominated(self, active):
        """Eliminate candidates Pareto-dominated (on validity and utility) by a client in ``active``.

        Client i is dominated when some active client's LCB exceeds client i's
        UCB on both validity and utility.
        """
        if not np.any(active):
            return
        valid_lcb = self.r_intersection_valid_pro[active, 0]
        utility_lcb = self.r_intersection_utility[active, 0]
        for i in range(self.client_num):
            if self.candidate[i] == 0:
                continue
            if self.cur_candidate_num <= self.client_num * self._MIN_CANDIDATE_RATIO:
                break
            dominated = (self.r_intersection_valid_pro[i][1] < valid_lcb) & (self.r_intersection_utility[i][1] < utility_lcb)
            if np.any(dominated):
                logging.info(f"utility delete: {self.candidate[i]}")
                self.candidate[i] = 0
                logging.info(f"self.cur_candidate_num: {self.cur_candidate_num}")
                self.cur_candidate_num = self.cur_candidate_num - 1
                logging.info(f"finish self.cur_candidate_num: {self.cur_candidate_num}")
                self.accumulate_select[i] = nan

    def update_train_sample(self):
        """Record the latest selection (``self.k_sort``); call once per selection round.

        For each selected client index:
        - append its characteristic to the linear-model training set;
        - store a validity value from ``get_whether_valid_participant`` in
          ``totalClients`` and ``train_v``;
        - increment its selection count.

        GP training samples are added separately in ``update_client_util``.
        """
        client_ids = np.array(self.candidate_clients)  # loop-invariant, built once
        for i in self.k_sort:
            client_id = client_ids[i]
            self.train_x_linear = np.append(self.train_x_linear, np.array([self.client_characteristic[i]]), axis=0)
            # NOTE(review): validity still comes from the simulation helper, not real feedback.
            self.totalClients[client_id]['validity'] = get_whether_valid_participant(self.client_characteristic[i], self.max_client_data_size, self.min_client_data_size)
            whether_valid = self.totalClients[client_id]['validity']
            self.train_v = np.append(self.train_v, np.array([whether_valid]), axis=0)
            self.accumulate_select[i] = self.accumulate_select[i] + 1
            logging.info(f'Client {client_id} whether_valid is {whether_valid}')

    def update_duration(self, client_id, duration):
        """Overwrite the stored duration of a registered client; unknown ids are ignored."""
        if client_id in self.totalClients:
            self.totalClients[client_id]['duration'] = duration

    def linear_model(self, train_x, train_v, candidate, client_characteristic):
        """Ridge-regression confidence intervals on client validity.

        Args:
            train_x: row i is the characteristic of the i-th selected client.
            train_v: element i is the observed validity of that selection.
            candidate: element i is non-zero while client i is a candidate.
            client_characteristic: row i is the characteristic of client i.

        Returns:
            ``(count, 2)`` array; row i is ``[mu - bound, mu + bound]`` for
            candidate clients and ``[0, 0]`` for the others. Unlike the GP
            model, the full training history is used here.
        """
        delta_pro = 0.02  # confidence parameter delta
        alpha = 1 + math.sqrt(math.log(4 / delta_pro) / 2)
        lambda_value = 1  # ridge regularisation lambda

        client_characteristic = np.array(client_characteristic)
        count = len(client_characteristic)
        len_characteristic = client_characteristic.shape[1]

        h_matrix = lambda_value * np.identity(len_characteristic) + train_x.T.dot(train_x)
        inv_h_matrix = np.linalg.inv(h_matrix)
        theta_estimation = inv_h_matrix.dot(train_x.T.dot(train_v.T))  # estimate of theta

        confidence_interval = np.zeros((count, 2))
        for i, features in enumerate(client_characteristic):
            if candidate[i] == 0:
                continue
            mu_estimation = features.dot(theta_estimation.T)
            bound_estimation = alpha * math.sqrt(features.dot(inv_h_matrix.dot(features.T)))
            confidence_interval[i] = np.array([mu_estimation - bound_estimation, mu_estimation + bound_estimation])
        return confidence_interval

    def kernel_function(self, client_characteristic, client_characteristic_other, length_scale):
        """Squared-exponential kernel value between two characteristic vectors."""
        dist = np.sum(np.square(client_characteristic - client_characteristic_other))
        return np.exp(-dist / (2 * length_scale ** 2))

    def kernel_matrix_function(self, train_x, length_scale):
        """Squared-exponential kernel (Gram) matrix over the rows of ``train_x``."""
        return np.exp(-squareform(pdist(np.array(train_x), 'euclidean') ** 2) / (2 * length_scale ** 2))

    def gaussian_process_model(self, train_x, train_u, candidate, client_characteristic):
        """GP confidence intervals on client utility.

        Args:
            train_x: row i is the characteristic of the i-th observed client.
            train_u: element i is the observed utility of that sample.
            candidate: element i is non-zero while client i is a candidate.
            client_characteristic: row i is the characteristic of client i.

        Returns:
            ``(count, 2)`` array; row i is ``[mu - beta * sd, mu + beta * sd]``
            for candidate clients and ``[0, 0]`` for the others. Only the most
            recent ``self.history_long`` training samples are used.

        Raises:
            ValueError: if there is no training sample.
        """
        sigma_par = 0.02  # noise std; sigma_par ** 2 is added to the kernel diagonal
        length_scale = 3  # squared-exponential kernel length scale
        delta_pro = 0.02  # confidence parameter delta

        if train_x.shape[0] > self.history_long:
            train_x = train_x[-self.history_long:]
            train_u = train_u[-self.history_long:]
        count_train_x = len(train_x)
        count = len(client_characteristic)
        if count_train_x == 0:
            raise ValueError("gaussian_process_model requires at least one training sample")

        beta_t_sqrt = math.sqrt(2 * math.log((count * math.pi ** 2 * count_train_x ** 2) / 3 / delta_pro))
        kernel_matrix = self.kernel_matrix_function(train_x, length_scale)
        inv_kernel_matrix_positive = np.linalg.inv(kernel_matrix + sigma_par ** 2 * np.eye(count_train_x))
        temp_mu = inv_kernel_matrix_positive.dot(np.array(train_u))

        train_x_arr = np.asarray(train_x)
        confidence_interval = np.zeros((count, 2))
        for i in range(count):
            if candidate[i] == 0:
                continue
            sq_dist = np.sum(np.square(train_x_arr - client_characteristic[i]), axis=1)
            small_k_vector_arr = np.exp(-sq_dist / (2 * length_scale ** 2))
            expectation_mu = small_k_vector_arr.dot(temp_mu)
            # Clamp at 0: rounding in the matrix inverse can push the posterior
            # variance slightly negative, which would make math.sqrt raise.
            posterior_var = 1 - small_k_vector_arr.dot(inv_kernel_matrix_positive.dot(small_k_vector_arr))
            std = math.sqrt(max(posterior_var, 0.0))
            confidence_interval[i] = [expectation_mu - beta_t_sqrt * std, expectation_mu + beta_t_sqrt * std]
        return confidence_interval

    def get_count_initialization(self):
        """Number of initialization rounds: ceil(client_num / k_channel), so every client is selected once."""
        quotient, remainder = divmod(self.client_num, self.k_channel)
        return quotient + (1 if remainder else 0)

    def every_round_client_selection(self):
        """Choose this round's ``k_channel`` client indices (helper of ``select_participant``).

        If exactly ``k_channel`` candidates remain they are all returned.
        Otherwise the candidate with the longest R_t(x) diagonal (squared
        widths of its validity and utility intervals, summed) is always
        included. The rest are the ``k_channel - 1`` candidates with the highest
        utility UCB; if the widest one is among those, the top ``k_channel`` by
        utility UCB are taken instead. Eliminated clients are never returned.
        The result is also stored in ``self.k_sort``.
        """
        if self.cur_candidate_num == self.k_channel:
            self.k_sort = np.where(self.candidate == 1)[0]
            return self.k_sort

        is_candidate = np.asarray(self.candidate) == 1

        valid_width = self.r_intersection_valid_pro[:, 1] - self.r_intersection_valid_pro[:, 0]
        utility_width = self.r_intersection_utility[:, 1] - self.r_intersection_utility[:, 0]
        # Squared diagonal w_t(x)^2 is enough for comparison; non-candidates never win.
        diagonal_length = np.where(is_candidate, valid_width ** 2 + utility_width ** 2, -np.inf)
        most_uncertain = int(np.argmax(diagonal_length))

        # Rank by utility UCB, ascending; non-candidates sink to the bottom.
        utility_ucb = np.where(is_candidate, self.r_intersection_utility[:, 1], -np.inf)
        order = np.argsort(utility_ucb)
        n_top = self.k_channel - 1
        # Guard n_top == 0: order[-0:] would be the whole array.
        top = order[-n_top:] if n_top > 0 else order[:0]
        if most_uncertain in top:
            self.k_sort = order[-self.k_channel:]
        else:
            self.k_sort = np.append(top, most_uncertain)
        return self.k_sort

    def get_characteristic_by_client_id(self, client_id):
        """Return the normalised characteristic vector of ``client_id``.

        Order: cpu, memory, gpu, datasize, batchsize, learningrate.
        """
        characteristic = get_client_id_characteristic_norm(client_id, self.data_map_file_characteristic_normalization)
        return np.array([characteristic[key] for key in ('cpu', 'memory', 'gpu', 'datasize', 'batchsize', 'learningrate')])
    
