import numpy as np
import random
import pandas as pd


def get_k_min(array, k):
    """Return the indices of the ``k`` smallest entries of a 1-D ``array``.

    The returned indices identify the ``k`` smallest values, in no guaranteed
    order (``np.argpartition`` only partitions, it does not sort).

    Args:
        array: 1-D array-like of comparable values.
        k: Number of indices to return. ``k <= 0`` yields an empty result;
            ``k >= len(array)`` yields the indices of every element.

    Returns:
        ``numpy.ndarray`` of integer indices with
        ``min(max(k, 0), len(array))`` entries.
    """
    n = len(array)
    if k <= 0 or n == 0:
        return np.array([], dtype=np.intp)
    # argpartition's kth must be a valid index (< n); passing k == n would
    # raise ValueError, so clamp it. For 0 < k < n this is exactly kth = k.
    kth = min(k, n - 1)
    return np.argpartition(array, kth)[:k]

def intersection(a1, b1, a2, b2):
    """Intersect the closed intervals ``[a1, b1]`` and ``[a2, b2]``.

    Helper for ``get_r_intersection``.

    Returns:
        ``np.array([left, right])`` holding the endpoints of the overlap, or
        ``None`` when the intervals are disjoint or the overlap is empty.
    """
    # Disjoint: one interval ends before the other begins.
    if b1 < a2 or b2 < a1:
        return None

    lo = max(a1, a2)
    hi = min(b1, b2)
    # Guards against an empty result (e.g. a malformed input interval).
    if lo > hi:
        return None

    return np.array([lo, hi])

def get_r_intersection(candidate, r_intersection, confidence_interval):
    """Intersect each running interval with its new confidence interval (R_t(x)).

    For every index ``i`` in ``range(len(r_intersection))`` whose
    ``candidate[i]`` is non-zero, ``r_intersection[i]`` is replaced by its
    intersection with ``confidence_interval[i]`` (computed by
    ``intersection``). If the two do not overlap, ``r_intersection[i]`` is
    reset to ``confidence_interval[i]``. Indices with ``candidate[i] == 0``
    are left untouched.

    Args:
        candidate: Indexable of flags; a zero entry means "skip this index".
        r_intersection: Indexable of ``[low, high]`` pairs; modified in place.
        confidence_interval: Indexable of ``[low, high]`` pairs, indexed like
            ``r_intersection``.

    Returns:
        The same ``r_intersection`` object after the in-place updates.
    """
    for idx in range(len(r_intersection)):
        if candidate[idx] != 0:
            current = r_intersection[idx]
            new_ci = confidence_interval[idx]
            overlap = intersection(current[0], current[1], new_ci[0], new_ci[1])
            # No overlap -> restart from the fresh confidence interval.
            r_intersection[idx] = new_ci if overlap is None else overlap
    return r_intersection

# Numerical simulation
def get_utility_simulation(client_i):
    """Draw one simulated utility sample for client ``client_i`` (0-based).

    Clients fall into four index ranges, each with its own linearly varying
    mean and spread:

    * ``[0, 20)``: mean 0.995 minus 0.005 per client; spread 0.01.
    * ``[20, 220)``: mean rises from 0.1 toward 0.8; spread falls from 0.015
      toward 0.005 (over 200 clients).
    * ``[220, 370)``: mean falls from 0.8 toward 0.1; spread rises from 0.005
      toward 0.01 (over 150 clients).
    * ``[370, ...)``: mean rises from 0.4 toward 0.8; spread rises from 0.001
      toward 0.015 (over 150 clients; larger indices keep extrapolating).

    Returns:
        numpy array of shape ``(1,)`` drawn with ``np.random.normal``.
    """
    # NOTE(review): the second argument of np.random.normal is the standard
    # deviation (scale), although the original code named these values
    # "var". Confirm which one is intended.
    if client_i < 20:
        return np.random.normal(0.995 - client_i * 0.005, 0.01, 1)

    if client_i < 220:
        offset = client_i - 20
        mean = 0.1 + offset * (0.8 - 0.1) / 200
        spread = 0.015 - offset * (0.015 - 0.005) / 200
        return np.random.normal(mean, spread, 1)

    if client_i < 370:
        offset = client_i - 220
        mean = 0.8 - offset * (0.8 - 0.1) / 150
        spread = 0.005 + offset * (0.01 - 0.005) / 150
        return np.random.normal(mean, spread, 1)

    offset = client_i - 370
    mean = 0.4 + offset * (0.8 - 0.4) / 150
    spread = 0.001 + offset * (0.015 - 0.001) / 150
    return np.random.normal(mean, spread, 1)


def get_max_min_entry_counts(file_path):
    """Return the largest and smallest number of rows per ``client_id``.

    Reads the comma-separated file at ``file_path`` and counts the rows for
    each distinct ``client_id``. Prints the maximum and minimum counts, then
    returns them.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``,
            e.g. ``/data/dataset/femnist/client_data_mapping/train.csv``.
            Must contain a ``client_id`` column.

    Returns:
        Tuple ``(max_count, min_count)``.
    """
    df = pd.read_csv(file_path, delimiter=',')

    # Number of rows for each distinct client_id.
    client_id_counts = df['client_id'].value_counts()

    # The extremes must be computed before they are printed; referencing
    # them first raises UnboundLocalError.
    max_count = client_id_counts.max()
    min_count = client_id_counts.min()

    print("最大数据条目数:", max_count)
    print("最小数据条目数:", min_count)

    return max_count, min_count

def get_count_entries(csv_file):
    """Count the rows belonging to each ``client_id`` in a CSV file.

    The first row is used as the header. Each ``client_id`` value is turned
    into a string, stripped of surrounding whitespace and cast to ``int``.
    The values are counted after that conversion.

    Args:
        csv_file: Path (or file-like object) accepted by ``pandas.read_csv``.

    Returns:
        dict mapping each integer ``client_id`` to its row count. The order
        is the one produced by ``Series.value_counts`` (descending count).
    """
    frame = pd.read_csv(csv_file, delimiter=',', header=0)

    # Normalise the ids: text -> strip whitespace -> integer.
    normalised_ids = (
        frame['client_id']
        .astype(str)
        .str.strip()
        .astype(int)
    )

    per_client = normalised_ids.value_counts()
    return per_client.to_dict()