from collections import OrderedDict
from .utils.auxiliary_to_random_mode import *

import numpy as np

from auxiliary_utils.client_characteristic import *


def create_training_selector_with_random(args):
    """Build the training selector used for the random selection mode.

    Args:
        args: configuration object, forwarded unchanged to ``_training_selector``.

    Returns:
        A new ``_training_selector`` constructed with its default ``sample_seed``.
    """
    selector = _training_selector(args)
    return selector

class _training_selector(object):
    """Training selector for the random selection mode.

    Holds per-client bookkeeping in ``totalClients``, the registered
    candidate clients (``candidate_clients``) together with their
    characteristic vectors (``client_characteristic``), and history buffers
    (``train_x_linear``/``train_v`` and ``train_x_gaussian``/``train_u``).
    """
    def __init__(self, args, sample_seed=233):
        """Initialise selector state from the run configuration ``args``.

        Args:
            args: configuration object. Attributes read here: ``utility_type``,
                ``k_channel``, ``client_num``, ``max_client_data_size``,
                ``min_client_data_size``, ``data_map_dir``,
                ``data_map_file_characteristic_normalization`` and
                ``data_map_file_characteristic``.
            sample_seed: value passed to ``np.random.seed``.
        """
        # NOTE: this reseeds NumPy's *global* RNG, not a selector-local generator.
        np.random.seed(sample_seed)

        # Per-client bookkeeping, keyed by client_id (filled by register_client).
        self.totalClients = OrderedDict()
        self.training_round = 0

        self.args = args
        self.utility_type = args.utility_type

        self.k_channel = args.k_channel  # number of clients selected in every round
        num_clients = args.client_num
        self.client_num = num_clients  # number of all clients
        self.max_client_data_size = args.max_client_data_size
        self.min_client_data_size = args.min_client_data_size

        # Length of a client characteristic vector; matches the 6 fields built by
        # get_characteristic_by_client_id (cpu, memory, gpu, datasize,
        # batchsize, learningrate).
        feature_dim = 6
        self.len_characteristic = feature_dim
        self.client_characteristic = []

        # candidate[i] == 1 marks client i as a candidate participant (the default).
        self.candidate = np.ones(num_clients)
        self.cur_candidate_num = num_clients
        # Cumulative number of times each client has been selected.
        self.accumulate_select = np.zeros(num_clients)

        # Row i: characteristics of the client on the longest diagonal at the
        # i-th selection; train_v[i]: the valid-participant value of that selection.
        self.train_x_linear = np.empty((0, feature_dim))
        self.train_v = np.empty(0)

        # Same row layout; train_u[i]: the utility of the i-th selection.
        self.train_x_gaussian = np.empty((0, feature_dim))
        self.train_u = np.empty(0)

        self.history_long = 15000  # length of historical sample data to maintain

        self.k_sort = np.empty(0)  # initial ids of the selected clients

        # Intersections R_t(x) of the interval estimates for validity / utility.
        self.r_intersection_valid_pro = np.empty(0)
        self.r_intersection_utility = np.empty(0)

        self.unexplored = set()  # client_ids not yet explored
        self.candidate_clients = []  # available client_ids, in registration order

        self.data_map_file = f'{args.data_map_dir}/train.csv'
        self.data_map_file_characteristic_normalization = args.data_map_file_characteristic_normalization
        self.data_map_file_characteristic = args.data_map_file_characteristic
        # Presumably a dict client_id -> number of data entries (helper defined elsewhere).
        self.client_id_counts = get_count_entries(self.data_map_file)
        self.accumulate_k_sort = []


    def register_client(self, client_id, feedbacks):
        """Register ``client_id`` the first time it is seen; repeat calls are no-ops.

        On first registration:
          * the client is appended to ``candidate_clients`` (and its
            characteristic vector to ``client_characteristic``) unless it is
            already listed there;
          * a bookkeeping record is stored in ``totalClients[client_id]``;
          * the client is added to ``unexplored``.

        Args:
            client_id: identifier of the client.
            feedbacks: mapping providing at least ``'reward'`` and ``'duration'``.

        Fields of the ``totalClients[client_id]`` record:
            reward: statistical utility, from ``feedbacks['reward']``.
            duration: system utility, from ``feedbacks['duration']``.
            time_stamp: last round the client was involved; starts at
                ``self.training_round``.
            count: number of times the client was involved; starts at 0.
            status: whether the client is available; starts as ``True``.
            validity: ``get_whether_valid_participant`` evaluated on the
                client's characteristic vector and the configured max/min
                client data sizes.
        """
        if client_id in self.totalClients:
            return

        if client_id not in self.candidate_clients:
            self.candidate_clients.append(client_id)
            self.client_characteristic.append(
                self.get_characteristic_by_client_id(client_id))
            # Just appended, so it is the last entry; no second O(n) scan needed.
            client_index = len(self.candidate_clients) - 1
        else:
            client_index = self.candidate_clients.index(client_id)

        # Build the record fully before inserting it, so a failure (e.g. a
        # missing feedbacks key) cannot leave a half-initialised entry that
        # would make later calls skip this client. Values are evaluated in the
        # same order as before: feedbacks lookups first, validity last.
        self.totalClients[client_id] = {
            'reward': feedbacks['reward'],
            'duration': feedbacks['duration'],
            'time_stamp': self.training_round,
            'count': 0,
            'status': True,
            'validity': get_whether_valid_participant(
                self.client_characteristic[client_index],
                self.max_client_data_size,
                self.min_client_data_size),
        }

        self.unexplored.add(client_id)

 
    # Simulate whether the client with this client_id trains successfully in the current round, i.e. this round's validity value.
    def get_round_simulate_validity(self, client_id):
        """Re-evaluate the validity of ``client_id`` for the current round.

        Finds the client's position in ``candidate_clients`` and calls
        ``get_whether_valid_participant`` on the matching entry of
        ``client_characteristic`` with the configured max/min client data sizes.
        An unknown ``client_id`` yields no match, so the ``[0]`` lookup raises
        ``IndexError``.
        """
        ids = np.array(self.candidate_clients)
        position = np.nonzero(ids == client_id)[0][0]
        features = self.client_characteristic[position]
        return get_whether_valid_participant(
            features, self.max_client_data_size, self.min_client_data_size)


    # Get the validity of the client. Note: this returns the client's interval estimate, not a single value.
    def get_validity(self, client_id):
        """Return the validity interval estimate produced by ``self.linear_model``.

        ``linear_model`` receives, in order: ``train_x_linear``, ``train_v``,
        ``candidate`` and ``client_characteristic``; its result is returned
        unchanged.

        NOTE(review): ``client_id`` is not used, so the result does not depend on
        it. Presumably ``linear_model`` returns estimates for all candidates;
        confirm against its definition.
        """
        model = self.linear_model
        features, labels = self.train_x_linear, self.train_v
        return model(features, labels, self.candidate, self.client_characteristic)
    
    # Get the client characteristic vector for a given client_id.
    # Example data_map_file: "/data/dataset/femnist/client_data_mapping/train.csv"
    def get_characteristic_by_client_id(self, client_id):
        """Return the normalized characteristic vector of ``client_id``.

        The per-client record comes from ``get_client_id_characteristic_norm``,
        called with ``self.data_map_file_characteristic_normalization``. The
        result is a 6-element NumPy array in the order cpu, memory, gpu,
        datasize, batchsize, learningrate.
        """
        record = get_client_id_characteristic_norm(
            client_id, self.data_map_file_characteristic_normalization)
        field_order = ('cpu', 'memory', 'gpu', 'datasize', 'batchsize', 'learningrate')
        return np.array([record[field] for field in field_order])

