import numpy as np
import random
import pandas as pd
import logging
import math


def get_k_min(array, k):
    """Return the indices of the ``k`` smallest entries of ``array``.

    Args:
        array: Array-like of comparable values (intended for 1-D input).
        k: Number of indices to return.

    Returns:
        A NumPy integer array holding the positions of the ``k`` smallest
        values. The indices are returned in no particular order. When ``k``
        is greater than or equal to the number of elements, every index is
        returned.
    """
    values = np.asarray(array)
    if values.ndim == 1 and k >= values.size:
        # np.argpartition requires kth < len(array); asking for all (or more
        # than all) elements would raise ValueError, so short-circuit here.
        return np.arange(values.size, dtype=np.intp)
    # Partitioning around position k puts the k smallest values in front.
    return np.argpartition(values, k)[:k]

def intersection(a1, b1, a2, b2):
    """Intersect the closed intervals ``[a1, b1]`` and ``[a2, b2]``.

    Helper for ``get_r_intersection``.

    Returns:
        ``np.array([left, right])`` with the end points of the overlap, or
        ``None`` when the two intervals do not overlap.
    """
    # One interval ends before the other one starts: nothing in common.
    disjoint = b1 < a2 or b2 < a1
    if disjoint:
        return None

    lower = max(a1, a2)
    upper = min(b1, b2)
    # Still empty, e.g. when one of the inputs has its bounds reversed.
    if lower > upper:
        return None
    return np.array([lower, upper])

def get_r_intersection(candidate, r_intersection, confidence_interval):
    """Intersect running intervals with new confidence intervals (R_t(x)).

    For every index whose ``candidate`` entry is non-zero, the stored
    interval in ``r_intersection`` is replaced by its overlap with the
    matching ``confidence_interval``. If the two do not overlap, the stored
    interval is replaced by the confidence interval itself. Indices whose
    ``candidate`` entry equals 0 are left untouched.

    Args:
        candidate: Per-index flags; 0 means "skip this index".
        r_intersection: Sequence of ``[left, right]`` intervals; updated
            in place.
        confidence_interval: Sequence of ``[left, right]`` intervals,
            indexed like ``r_intersection``.

    Returns:
        The same ``r_intersection`` object after the in-place update.
    """
    for idx in range(len(r_intersection)):
        if candidate[idx] == 0:
            continue
        current = r_intersection[idx]
        fresh = confidence_interval[idx]
        overlap = intersection(current[0], current[1], fresh[0], fresh[1])
        # No overlap: restart from the new confidence interval.
        r_intersection[idx] = fresh if overlap is None else overlap
    return r_intersection



# numerical simulation
def get_client_characteristic(count):
    """Generate a random, normalised feature matrix for simulated clients.

    Each client gets six features, in column order: CPU cores, memory, GPU
    count, data size (GB), batch size and learning rate. The candidate
    values a client may draw from depend on its index (see the tier table
    below). Every column is finally divided by the largest value of its
    candidate pool.

    Args:
        count: Number of clients.

    Returns:
        A ``(count, 6)`` float array of normalised features. Values are
        drawn with the ``random`` module, so results depend on its state.
    """
    # Candidate pools, one per feature column.
    cpu_pool = [50 + 2 * step for step in range(51)]  # even core counts in [50, 150]
    memory_pool = [128, 256, 512, 1024]
    gpu_pool = [2, 4, 6, 8]
    datasize_pool = [0.5 + step * 0.5 for step in range(20)]  # [0.5, 10] GB
    batchsize_pool = [128, 256, 512, 1024, 2048]
    learningrate_pool = [0.0001, 0.001, 0.01, 0.1]

    # (exclusive upper client index, per-column candidate sub-pools).
    # The first tier whose bound exceeds the client index applies.
    tiers = (
        (20, (cpu_pool[45:51], memory_pool[2:4], gpu_pool[2:4],
              datasize_pool[18:20], batchsize_pool[4:5], learningrate_pool[0:2])),
        # Clients 20-219: labelled "class one" in the original comments.
        (70, (cpu_pool[0:10], memory_pool[0:3], gpu_pool[0:2],
              datasize_pool[0:5], batchsize_pool[0:2], learningrate_pool[2:3])),
        (120, (cpu_pool[10:20], memory_pool[1:3], gpu_pool[0:1],
               datasize_pool[5:10], batchsize_pool[1:3], learningrate_pool[1:4])),
        (170, (cpu_pool[20:35], memory_pool[2:4], gpu_pool[1:2],
               datasize_pool[10:15], batchsize_pool[2:4], learningrate_pool[0:4])),
        (220, (cpu_pool[35:49], memory_pool[3:4], gpu_pool[2:3],
               datasize_pool[15:20], batchsize_pool[2:4], learningrate_pool[0:4])),
        # Clients 220-369: labelled "class two" in the original comments.
        (370, (cpu_pool[20:30], memory_pool[1:3], gpu_pool[1:3],
               datasize_pool[5:15], batchsize_pool[1:4], learningrate_pool[1:3])),
        # Clients 370 and above: labelled "class three".
        (float("inf"), (cpu_pool[0:10], memory_pool[0:2], gpu_pool[0:2],
                        datasize_pool[15:20], batchsize_pool[3:5], learningrate_pool[2:4])),
    )

    features = np.zeros((count, 6))
    for row in range(count):
        choices = next(pools for bound, pools in tiers if row < bound)
        # Draw the columns left to right so the random stream is consumed
        # in a fixed order.
        for col, pool in enumerate(choices):
            features[row][col] = random.sample(pool, 1)[0]

    # Normalise each column by the maximum of its candidate pool.
    features /= np.array([150, 1024, 8, 10, 2048, 0.1])
    return features

# numerical simulation
# def get_whether_valid_participant(client_i):
#     # client_i: 第i个client, 从0开始计数
#     if client_i < 20:
#         valid_mu = 0.850 + client_i * 0.005
#     elif client_i < 220:
#         valid_mu = 0.1 + (client_i-20) * (0.8-0.1)/200
#     elif client_i < 370:
#         valid_mu = 0.45 + (client_i-220) * (0.8-0.45)/150
#     else:
#         valid_mu = 0.45 - (client_i-370) * (0.45-0.1)/150
#     return np.random.binomial(1, valid_mu)

# numerical simulation
def get_utility_simulation(client_i):
    """Draw one simulated utility value for a client.

    The mean and spread of the normal distribution are piecewise-linear
    functions of the client index, with break points at 20, 220 and 370.

    Args:
        client_i: Zero-based index of the client.

    Returns:
        A NumPy array of length 1 holding the sampled utility.

    Note:
        The spread is handed to ``np.random.normal`` as its ``scale``
        argument, i.e. it acts as a standard deviation.
    """
    if client_i < 20:
        mean = 0.995 - client_i * 0.005
        spread = 0.01
    elif client_i < 220:
        offset = client_i - 20
        mean = 0.1 + offset * (0.8 - 0.1) / 200
        spread = 0.015 - offset * (0.015 - 0.005) / 200
    elif client_i < 370:
        offset = client_i - 220
        mean = 0.8 - offset * (0.8 - 0.1) / 150
        spread = 0.005 + offset * (0.01 - 0.005) / 150
    else:
        offset = client_i - 370
        mean = 0.4 + offset * (0.8 - 0.4) / 150
        spread = 0.001 + offset * (0.015 - 0.001) / 150

    return np.random.normal(loc=mean, scale=spread, size=1)





def get_max_min_entry_counts(file_path):
    """Report the largest and smallest number of rows owned by one client.

    Reads a comma-separated file that has a ``client_id`` column, counts
    the rows per distinct ``client_id`` and prints and returns the extremes.

    Args:
        file_path: Path of the CSV file, e.g.
            "/data/dataset/femnist/client_data_mapping/train.csv".

    Returns:
        A ``(max_count, min_count)`` tuple.
    """
    df = pd.read_csv(file_path, delimiter=',')

    # Number of rows for each distinct client_id.
    client_id_counts = df['client_id'].value_counts()

    # Compute the extremes before printing them; printing first would hit
    # the not-yet-assigned locals and raise UnboundLocalError.
    max_count = client_id_counts.max()
    min_count = client_id_counts.min()

    print("最大数据条目数:", max_count)
    print("最小数据条目数:", min_count)

    return max_count, min_count

def get_count_entries(csv_file):
    """Count the rows belonging to each ``client_id`` in a CSV file.

    Args:
        csv_file: Path of a comma-separated file whose first line is the
            header and which has a ``client_id`` column.

    Returns:
        A dict mapping each integer ``client_id`` to its number of rows.
    """
    frame = pd.read_csv(csv_file, delimiter=',', header=0)

    # Normalise the ids: render as text, trim surrounding whitespace, then
    # parse as integers.
    frame['client_id'] = frame['client_id'].astype(str).str.strip().astype(int)

    per_client = frame['client_id'].value_counts()
    return per_client.to_dict()




