from collections import Counter
import os
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
import time
import warnings
import pandas as pd
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
import argparse
parser = argparse.ArgumentParser(description='FairGBFC')
parser.add_argument('--sen_index', type=int, default=1, help='index of sensitive attribute')
# Parsed at import time because GranularBall reads the module-level ``args``.
# parse_known_args (instead of parse_args) keeps the module importable from a
# host process with its own argv (pytest, Jupyter, another script): unknown
# options are ignored instead of triggering SystemExit.
# NOTE: main() overwrites args.sen_index with the index returned by load_data.
args, _unknown_args = parser.parse_known_args()

class GranularBall:
    """A granular ball: a group of labelled samples summarised by centre and radius.

    ``data`` rows are laid out as ``[label, feature_0, feature_1, ...]``.

    Attributes
    ----------
    data : ndarray
        The rows belonging to this ball.
    center, radius :
        Mean of the feature columns and mean Euclidean distance of the rows
        to that mean (see ``calCenterRadius``).
    num : int
        Number of rows in the ball.
    label, purity :
        Majority label and the fraction of rows carrying it.
    ps1, ps2 : float
        Positive-label rate among rows whose sensitive value is 0 / 1
        (see ``compute_ps``).
    """

    def __init__(self, data) -> None:
        self.data = data
        self.center, self.radius = calCenterRadius(data[:, 1:])
        self.num = len(data)
        self.label, self.purity = calLabelPurity(data)
        # main() sets args.sen_index to the sensitive attribute's position among
        # the FEATURE columns (it indexes X = data[:, 1:]), but compute_ps indexes
        # the full row whose column 0 is the label. Shift by one so the
        # sensitive column itself is read, not the feature before it.
        self.ps1, self.ps2 = compute_ps(data, args.sen_index + 1)

def compute_ps(data, sen_index):
    """Return the positive-label rate inside the two sensitive groups of ``data``.

    Column 0 of ``data`` holds the labels and column ``sen_index`` the
    sensitive attribute. The first value returned is the fraction of rows with
    label 1 among rows whose sensitive value equals 0. The second is the same
    fraction for sensitive value 1. A group with no rows yields 0.0.
    """
    labels = data[:, 0]
    groups = data[:, sen_index]

    def _positive_rate(group_value):
        members = np.flatnonzero(groups == group_value)
        if members.size == 0:
            return 0.0
        return np.mean(labels[members] == 1)

    return _positive_rate(0), _positive_rate(1)

#Calculate the label and purity of GB
def calLabelPurity(data):
    """Return ``(label, purity)`` for the rows of ``data``.

    ``label`` is the most frequent value of column 0. On ties, the value seen
    first wins. ``purity`` is that label's share of the rows. A single row is
    trivially pure (1.0).
    """
    if len(data) <= 1:
        return data[0][0], 1.0
    tally = Counter(data[:, 0])
    majority = max(tally, key=tally.get)
    return majority, tally[majority] / len(data)


#Calculate the center and radius of GB
def calCenterRadius(data):
    """Return ``(center, radius)`` of the feature rows in ``data``.

    ``center`` is the column-wise mean. ``radius`` is the mean Euclidean
    distance from each row to that center.
    """
    centroid = np.mean(data, axis=0)
    spreads = calculateDist(data, centroid)
    return centroid, np.mean(spreads)

#Calculate the Euclidean distance between objects
def calculateDist(A, B, flag=0):
    """Euclidean distance(s) between ``A`` and ``B`` after broadcasting.

    With ``flag == 0`` the squared differences are summed along axis 1, which
    gives one distance per row. Any other ``flag`` sums over every element and
    returns a single scalar distance.
    """
    squared = (A - B) ** 2
    if flag != 0:
        return np.sqrt(np.sum(squared))
    return np.sqrt(np.sum(squared, axis=1))

#Constructing clusters formed by the majority of samples in data
def generateOmegaCluster(data):
    """Split ``data`` into a majority-class cluster and the remaining rows.

    The majority label of column 0 is found first. The centroid of that
    class's feature columns is computed. The radius is the mean (rounded to 6
    decimals) of the majority rows' distances to that centroid. Every row, of
    any label, whose rounded distance is within the radius goes into the
    cluster; the rest are returned for further processing.

    Returns
    -------
    (cluster, todo_data)
        Both are row subsets of ``data``. An empty subset is returned as a
        1-D empty array, as the original row-by-row construction produced.
    """
    labels = data[:, 0]
    tally = Counter(labels)
    majority = max(tally, key=tally.get)
    is_majority = labels == majority

    # Centroid of the majority class, and every row's distance to it.
    centre = np.mean(data[is_majority, 1:], axis=0)
    distances = np.around(calculateDist(data[:, 1:], centre), 6)
    radius = np.around(np.mean(distances[is_majority]), 6)
    inside = distances <= radius

    def _take(mask):
        # Keep the shape the list-then-np.array construction gave: np.array([])
        # when no row qualifies.
        return data[mask] if mask.any() else np.array([])

    return _take(inside), _take(~inside)


# Eliminate conflicts: merge granular balls with different labels when one is nested inside the other
def removeConflicts(ball_list):
    """Merge heterogeneous granular balls that are nested inside one another.

    Balls ``i < j`` conflict when their labels differ and the distance between
    their centres is at most ``|radius_i - radius_j|``, i.e. one ball lies
    inside the other. In that case ball ``i``'s rows are merged into ball
    ``j``, ball ``j`` is rebuilt from the merged rows, and ball ``i`` is
    dropped. Same-label balls are never merged.

    Parameters
    ----------
    ball_list : list of GranularBall
        Balls to reconcile. Modified in place.

    Returns
    -------
    list of GranularBall
        The same list object with the absorbed balls removed.
    """
    absorbed = []
    for i in range(len(ball_list) - 1):
        if ball_list[i] in absorbed:
            continue
        for j in range(i + 1, len(ball_list)):
            if ball_list[j] in absorbed:
                continue
            outer, inner = ball_list[j], ball_list[i]
            heterogeneous = inner.label != outer.label
            nested = calculateDist(inner.center, outer.center,
                                   flag=1) <= abs(inner.radius - outer.radius)
            if heterogeneous and nested:
                merged = np.concatenate((outer.data, inner.data))
                # Rebuild ball j from scratch so that EVERY derived attribute
                # matches the merged rows. Updating only label/purity/center/
                # radius left `num` (used as the density term in GB_KNN1) and
                # ps1/ps2 describing the pre-merge ball.
                ball_list[j] = GranularBall(merged)
                absorbed.append(inner)
                break
    # GranularBall defines no __eq__/__hash__, so membership and removal are
    # by object identity.
    for ball in set(absorbed):
        ball_list.remove(ball)
    return ball_list

def load_data(name):
    """Load a processed benchmark CSV and print its label / sensitive-group statistics.

    NOTE: a second ``load_data`` is defined further down in this module. Being
    defined later, it replaces this one when the module is imported, so this
    version is not the one ``main()`` calls (it also knows fewer datasets).

    Parameters
    ----------
    name : str
        Dataset name without the ``.csv`` suffix.

    Returns
    -------
    tuple
        ``(data, sen_index, cat_indices, cont_indices)``:

        - ``data``: the DataFrame, with the label in column 0.
        - ``sen_index``: the sensitive attribute's index among the feature
          columns (its CSV column index minus one).
        - ``cat_indices`` / ``cont_indices``: the categorical and continuous
          feature index lists.

        ``(None, None, None, None)`` is returned when the CSV file is missing.

    Raises
    ------
    ValueError
        If ``name`` is not one of the datasets configured below.
    """
    src_path = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(src_path, '..', 'Datasets/processed dataset/')
    name = name + '.csv'
    print(dataset_path)
    # Per dataset: sen_index is the sensitive column's index in the CSV (the
    # label column counts). cat/cont indices are column indices among the
    # feature (attribute) columns.
    if name == 'credit_approval.csv':  # Gender
        sen_index = 1
        cat_indices = [0, 3, 4, 5, 6, 8, 9, 11, 12]
        cont_indices = [1, 2, 7, 10, 13, 14]
    elif name == 'adult.csv':  # sex
        sen_index = 8
        cat_indices = [1, 2, 3, 4, 5, 6, 7, 9]
        cont_indices = [0, 8]
    elif name == 'credit_default.csv':  # SEX
        sen_index = 2  # counts the label column
        cat_indices = [1, 2, 3, 5, 6, 7, 8, 9, 10]
        cont_indices = [0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
    elif name == 'law_admission.csv':  # gender
        sen_index = 6
        cat_indices = [4, 5, 6]
        cont_indices = [0, 1, 2, 3]
    elif name == 'german.csv':  # Gender
        sen_index = 1
        cat_indices = [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]
        cont_indices = [3, 4, 5]
    elif name == 'por.csv':  # sex
        sen_index = 2
        cat_indices = [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]
        cont_indices = [2, 28, 29, 30]
    else:
        # Without this branch an unknown name fell through and failed later
        # with an UnboundLocalError on `sen_index`.
        raise ValueError(f"Unknown dataset: {name!r}")

    # Load the data
    data_path = os.path.join(dataset_path, name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None
    data = pd.read_csv(data_path)

    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]

    # Summary statistics
    print(f"\nDataset: {name}")

    # Distribution of the sensitive attribute
    sen_col_ratio = sensi_col.value_counts(normalize=True)
    print(f"\n✅ 敏感属性列的值比例：")
    print(sen_col_ratio)

    # Distribution of the label
    col0_ratio = label_col.value_counts(normalize=True)
    print("\n✅ 标签的值比例：")
    print(col0_ratio)

    # Dataset-level DP: positive-label rate gap between the first two
    # sensitive values in order of appearance.
    pos_label = 1
    groups = sensi_col.unique()
    if len(groups) >= 2:
        g1, g2 = groups[0], groups[1]
        g1_data = data[sensi_col == g1]
        g2_data = data[sensi_col == g2]

        g1_pos_rate = (g1_data.iloc[:, 0] == pos_label).sum() / len(g1_data)
        g2_pos_rate = (g2_data.iloc[:, 0] == pos_label).sum() / len(g2_data)
        dp = abs(g1_pos_rate - g2_pos_rate)
        print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

        # More detailed counts
        g1_total = len(g1_data)
        g1_pos = (g1_data.iloc[:, 0] == pos_label).sum()
        g2_total = len(g2_data)
        g2_pos = (g2_data.iloc[:, 0] == pos_label).sum()
        pos_total = (label_col == pos_label).sum()
        neg_total = (label_col != pos_label).sum()

        print(f"\n🔍 样本详细统计：")
        print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
        print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
        print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
        print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
        print(f"  - 正类样本总数：{pos_total}")
        print(f"  - 负类样本总数：{neg_total}")

    # Return the data plus metadata; sen_index is shifted to index the
    # feature columns (the label occupies column 0).
    return data, sen_index - 1, cat_indices, cont_indices



# Iteratively split the samples of one GB
# Returns a list of GranularBall objects (empty when fewer than two balls were produced)
def splitGB(gb):
    """Split the rows of ``gb`` into granular balls by peeling off Omega clusters.

    ``generateOmegaCluster`` is applied repeatedly to the rows not yet
    assigned. Each cluster with more than one row becomes a ``GranularBall``.
    Peeling stops once the remaining rows all carry distinct labels (number
    of distinct labels equals number of rows), or at most one row remains.

    Returns
    -------
    list of GranularBall
        The balls after ``removeConflicts`` when at least two were produced,
        otherwise an empty list.
    """
    remaining = gb
    n_labels = len(np.unique(gb[:, 0]))
    n_remaining = len(remaining)
    produced = []
    while n_labels != n_remaining:
        cluster, rest = generateOmegaCluster(remaining)
        if len(cluster) > 1:
            produced.append(GranularBall(cluster))
        remaining = rest
        n_remaining = len(remaining)
        if n_remaining <= 1:
            break
        n_labels = len(np.unique(remaining[:, 0]))
    # NOTE(review): when exactly one ball is produced it is discarded (an empty
    # list is returned), and generateGBList then drops the parent ball
    # entirely. Confirm this is intended rather than returning that one ball.
    if len(produced) > 1:
        return removeConflicts(produced)
    return []


# Iteratively construct GBs that meet threshold conditions for the entire training dataset
def generateGBList(data, purity):
    """Grow a list of granular balls over ``data`` until every ball meets ``purity``.

    Start from a single ball holding all rows and walk the list with a
    cursor. A ball whose purity is at least ``purity`` is kept and the cursor
    advances. An impure ball is split with ``splitGB``, with three outcomes:

    - several balls: the first replaces it in place and the rest are appended;
    - one ball of a different size: the old ball is removed and the new one
      appended;
    - otherwise (nothing, or an unchanged ball): the old ball is dropped.

    The walk ends when the cursor reaches the end of the list.
    """
    balls = [GranularBall(data)]
    cursor = 0
    while True:
        current = balls[cursor]
        if current.purity < purity:
            pieces = splitGB(current.data)
            if len(pieces) > 1:
                balls[cursor] = pieces[0]
                balls.extend(pieces[1:])
            elif len(pieces) == 1 and len(pieces[0].data) != len(current.data):
                del balls[cursor]
                balls.append(pieces[0])
            else:
                del balls[cursor]
        else:
            cursor += 1
        if cursor == len(balls):
            break
    return balls

#Determine the label of the test sample
def GB_KNN1(X_test, Ball_list):
    """Predict a label for each test row from the best-scoring granular ball.

    For test row ``x`` and ball ``k`` the score is
    ``dist(x, center_k) - num_k / total``. That is the Euclidean distance to
    the ball centre, reduced by the share of training samples the ball holds,
    so larger balls win near-ties. The row takes the label of the
    lowest-scoring ball; on exact ties the earliest ball in ``Ball_list``
    wins.

    Parameters
    ----------
    X_test : iterable of 1-D arrays
        Test feature rows (no label column).
    Ball_list : list of GranularBall
        Trained balls.

    Returns
    -------
    (predict_label, samp_all)
        The list of predicted labels, and the total sample count over all
        balls.

    Raises
    ------
    ValueError
        If ``Ball_list`` is empty.
    """
    if not Ball_list:
        raise ValueError("GB_KNN1 needs at least one granular ball")
    # Stack the centres once. Passing a Python list made calculateDist convert
    # it to an array again for every single test row.
    ball_centers = np.array([ball.center for ball in Ball_list])
    ball_num_list = [ball.num for ball in Ball_list]
    samp_all = np.sum(ball_num_list)
    ball_density_list = np.asarray(ball_num_list) / samp_all

    predict_label = []
    for row in X_test:
        scores = calculateDist(row, ball_centers) - ball_density_list
        # np.argmin returns the first minimum, matching list.index(min(...)).
        predict_label.append(Ball_list[int(np.argmin(scores))].label)
    return predict_label, samp_all

def compute_weights(x, gb_centers, gb_radiuss, sigma):
    '''
    Normalised Gaussian weights that each granular ball gives to one sample.

    x: a single test sample, shape (d,)
    gb_centers: centres of all balls, shape (n_balls, d)
    gb_radiuss: radii of all balls, shape (n_balls,)
    sigma: bandwidth of the Gaussian kernel; must be positive

    Returns an array of shape (n_balls,) that sums to 1. Each entry is
    proportional to exp(-(||x - c_k|| - r_k)^2 / (2 * sigma^2)).

    Raises ValueError if sigma is not positive.
    '''
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")

    # Distance from x to each ball centre minus that ball's radius; negative
    # when x lies inside the ball.
    # NOTE(review): squaring below treats being deep inside a ball like being
    # equally far outside it. If "distance to the ball" was intended, clip at
    # zero (np.maximum(dists, 0)) -- confirm before changing.
    dists = calculateDist(x, gb_centers) - gb_radiuss  # (n_balls,)

    # Log of the unnormalised Gaussian weights.
    log_w = -(dists ** 2) / (2 * sigma ** 2)

    # Normalisation is shift-invariant: exp(z - m) / sum(exp(z - m)) equals
    # exp(z) / sum(exp(z)). Shifting by the maximum keeps the largest term at
    # 1, so the denominator cannot underflow to 0 and yield NaN weights when x
    # is far from every ball.
    weights = np.exp(log_w - np.max(log_w))
    weights /= np.sum(weights)

    return weights


def _print_dataset_stats(data, sen_index, file_name, pos_label=1):
    """Print the label balance, sensitive-group balance and dataset-level DP gap.

    Parameters
    ----------
    data : pandas.DataFrame
        Loaded dataset, label in column 0.
    sen_index : int
        Column index of the sensitive attribute in ``data`` (the label column counts).
    file_name : str
        CSV file name, shown in the header line.
    pos_label : int, optional
        Label value treated as the positive class.
    """
    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]

    print(f"\nDataset: {file_name}")

    # Distribution of the sensitive attribute
    print("\n✅ 敏感属性列的值比例：")
    print(sensi_col.value_counts(normalize=True))

    # Distribution of the label
    print("\n✅ 标签的值比例：")
    print(label_col.value_counts(normalize=True))

    # DP of the raw labels between the first two sensitive values (in order of
    # appearance).
    groups = sensi_col.unique()
    if len(groups) < 2:
        return
    g1, g2 = groups[0], groups[1]
    is_pos = label_col == pos_label
    g1_mask = sensi_col == g1
    g2_mask = sensi_col == g2
    g1_total = int(g1_mask.sum())
    g2_total = int(g2_mask.sum())
    g1_pos = int((g1_mask & is_pos).sum())
    g2_pos = int((g2_mask & is_pos).sum())
    # A NaN group value never compares equal to itself, which leaves an empty
    # group; report NaN for it instead of dividing by zero.
    g1_pos_rate = g1_pos / g1_total if g1_total else float('nan')
    g2_pos_rate = g2_pos / g2_total if g2_total else float('nan')
    dp = abs(g1_pos_rate - g2_pos_rate)
    print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

    # More detailed counts
    print("\n🔍 样本详细统计：")
    print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
    print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
    print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
    print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
    print(f"  - 正类样本总数：{int(is_pos.sum())}")
    print(f"  - 负类样本总数：{int((label_col != pos_label).sum())}")


def load_data(name):
    """Load a processed fairness benchmark CSV and print summary statistics.

    Parameters
    ----------
    name : str
        Dataset name without the ``.csv`` suffix, e.g. ``'adult'``. The
        ``adult1`` and ``law_admission1`` variants read the same CSV as their
        base dataset but select a different sensitive attribute.

    Returns
    -------
    tuple
        ``(data, sen_index, cat_indices, cont_indices)``:

        - ``data``: the DataFrame, with the label in column 0.
        - ``sen_index``: the sensitive attribute's index among the feature
          columns (its CSV column index minus one).
        - ``cat_indices`` / ``cont_indices``: the categorical and continuous
          feature index lists.

        ``(None, None, None, None)`` is returned when the CSV file is missing.

    Raises
    ------
    ValueError
        If ``name`` is not a configured dataset. Previously such names, and
        the orphaned ``credit_approval1`` alias, failed later with an
        ``UnboundLocalError``.
    """
    src_path = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(src_path, '..', 'Datasets/processed dataset/')
    print(dataset_path)

    # name -> (sensitive column index in the CSV, label column included;
    #          categorical feature indices; continuous feature indices;
    #          CSV file stem)
    configs = {
        'credit_approval': (1, [0, 3, 4, 5, 6, 8, 9, 11, 12], [1, 2, 7, 10, 13, 14], 'credit_approval'),  # Gender
        'adult': (8, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8], 'adult'),  # sex
        'adult1': (7, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8], 'adult'),  # race
        'bail': (1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], 'bail'),  # WHITE
        'bank': (1, [0, 1, 2, 3, 4, 6, 7, 8, 13], [5, 9, 10, 11, 12], 'bank'),  # age
        'compas': (3, [1, 2, 3], [0, 4, 5, 6], 'compas'),  # race
        'credit_default': (2, [1, 2, 3, 5, 6, 7, 8, 9, 10],
                           [0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], 'credit_default'),  # SEX
        'law_admission': (6, [4, 5, 6], [0, 1, 2, 3], 'law_admission'),  # gender
        'law_admission1': (7, [4, 5, 6], [0, 1, 2, 3], 'law_admission'),  # race
        'recruitment': (2, [1, 2, 9], [0, 3, 4, 5, 6, 7, 8], 'recruitment'),  # Gender
        'thyroid': (2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [0], 'thyroid'),  # Gender
        'german': (1, [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],
                   [3, 4, 5], 'german'),  # Gender
        'math': (2, [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
                 [2, 28, 29, 30], 'math'),  # sex
        'por': (2, [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
                [2, 28, 29, 30], 'por'),  # sex
        'meps': (2, [1] + list(range(5, 138)), [0, 2, 3, 4], 'meps'),  # RACE
        'diabetic': (1, list(range(0, 5)) + list(range(11, 23)), [6, 7, 8, 9, 10], 'diabetic'),  # race
    }
    if name not in configs:
        raise ValueError(f"Unknown dataset: {name!r}")
    sen_index, cat_indices, cont_indices, csv_stem = configs[name]

    # Load the data
    file_name = csv_stem + '.csv'
    data_path = os.path.join(dataset_path, file_name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None
    data = pd.read_csv(data_path)

    _print_dataset_stats(data, sen_index, file_name)

    # Shift to an index over the feature columns (the label occupies column 0).
    return data, sen_index - 1, cat_indices, cont_indices

def compute_DP(X_test, pred_test, sen_index):
    """Absolute demographic-parity gap of ``pred_test`` between sensitive values 0 and 1.

    The sensitive attribute is read from column ``sen_index - 1`` of
    ``X_test``. For each of the two groups the mean prediction is taken; an
    empty group counts as 0.0. The absolute difference is returned.
    """
    groups = X_test[:, sen_index - 1]
    preds = np.array(pred_test)
    rates = []
    for value in (0, 1):
        members = np.flatnonzero(groups == value)
        rates.append(np.mean(preds[members]) if members.size else 0.0)
    return np.abs(rates[0] - rates[1])

def compute_DP_AOD_EOD_hard(X_test, Y_test, pred_label, sen_index):
    """
    Group-fairness gaps of hard 0/1 predictions between sensitive values 0 and 1.

    Parameters
    ----------
    X_test : array-like, 2-D
        Feature matrix. The sensitive attribute is read from column
        ``sen_index - 1``.
    Y_test : array-like, 1-D
        True 0/1 labels (a plain list is accepted).
    pred_label : array-like, 1-D
        Predicted 0/1 labels.
    sen_index : int
        1-based position of the sensitive attribute among the columns of ``X_test``.

    Returns
    -------
    abs_dp_value : float
        |P(pred=1 | s=0) - P(pred=1 | s=1)|
    abs_aod_value : float
        |0.5 * ((TPR_0 - TPR_1) + (FPR_0 - FPR_1))|
    abs_eod_value : float
        |TPR_0 - TPR_1|

    An empty group contributes rates of 0.0. TPR/FPR denominators include a
    1e-6 term to avoid division by zero.
    """
    X_test = np.asarray(X_test)
    # Y_test is indexed with integer arrays below; a plain list would raise.
    Y_test = np.asarray(Y_test)
    pred_label = np.asarray(pred_label)
    sen_attr = X_test[:, sen_index - 1]

    idx_0 = np.where(sen_attr == 0)[0]
    idx_1 = np.where(sen_attr == 1)[0]

    # DP: difference in the share of positive predictions per group.
    mean_pred_0 = np.mean(pred_label[idx_0]) if len(idx_0) > 0 else 0.0
    mean_pred_1 = np.mean(pred_label[idx_1]) if len(idx_1) > 0 else 0.0
    abs_dp_value = abs(mean_pred_0 - mean_pred_1)

    def get_TPR_FPR(y_true, y_pred):
        TP = np.sum((y_true == 1) & (y_pred == 1))
        FN = np.sum((y_true == 1) & (y_pred == 0))
        FP = np.sum((y_true == 0) & (y_pred == 1))
        TN = np.sum((y_true == 0) & (y_pred == 0))
        TPR = TP / (TP + FN + 1e-6)
        FPR = FP / (FP + TN + 1e-6)
        return TPR, FPR

    TPR_0, FPR_0 = get_TPR_FPR(Y_test[idx_0], pred_label[idx_0]) if len(idx_0) > 0 else (0.0, 0.0)
    TPR_1, FPR_1 = get_TPR_FPR(Y_test[idx_1], pred_label[idx_1]) if len(idx_1) > 0 else (0.0, 0.0)

    # EOD compares TPRs; AOD averages the TPR and FPR differences.
    abs_eod_value = abs(TPR_0 - TPR_1)
    abs_aod_value = abs(0.5 * ((TPR_0 - TPR_1) + (FPR_0 - FPR_1)))

    return abs_dp_value, abs_aod_value, abs_eod_value

def _dedupe_by_features(data):
    """Drop rows whose feature vector (columns 1:) already appeared, keeping the first.

    Row order is preserved. The label column takes no part in the comparison.
    Returns ``np.array`` of the surviving rows.
    """
    seen = set()
    kept = []
    for row in data.tolist():
        # A set of tuples gives O(1) membership. The previous list-membership
        # test made deduplication quadratic in the number of rows.
        key = tuple(row[1:])
        if key not in seen:
            seen.add(key)
            kept.append(row)
    return np.array(kept)


def _avg_std_lines(metric, values, dataset_name):
    """Return two report lines with the mean and the np.std of ``values``."""
    return (f'Average {metric} of {dataset_name}: {np.mean(values):.6f}\n'
            f'Std {metric} of {dataset_name}: {np.std(values):.6f}\n')


def main():
    """Run stratified k-fold evaluation of the granular-ball classifier.

    For every dataset in ``filenames`` and every purity threshold in
    ``puritys``, the model is trained on each fold (``generateGBList``) and
    used to predict the held-out fold (``GB_KNN1``). Mean and std of
    accuracy, F1, recall, DP, EO, GE, ball count and runtime are appended to
    ``<cwd>/FairGBFC/comparative_result/GBG/<dataset>.txt``.
    """
    warnings.filterwarnings("ignore")  # silence library warnings during the sweep
    # filenames = ['credit_approval', 'adult', 'bail', 'bank', 'credit_default', 'law_admission','recruitment', 'compas', 'thyroid','german', 'por', 'math']
    filenames = ['law_admission1', 'bank', 'adult1']
    n_splits = 5
    np.random.seed(42)
    puritys = [0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.00]
    # The original used os.path.dirname(os.path.realpath('__file__')), i.e. the
    # literal string '__file__' rather than the module path, which resolves
    # to the current working directory. Results keep going there; this just
    # makes it explicit.
    src_path = os.path.realpath(os.getcwd())
    result_dir = os.path.join(src_path, 'FairGBFC/comparative_result/', 'GBG')
    os.makedirs(result_dir, exist_ok=True)
    metric_names = ('accuracy', 'f1', 'recall', 'dp', 'eo', 'ge', 'gb_num', 'time')

    for dataset_name in filenames:
        data_frame, sen_index, _cat_indices, _cont_indices = load_data(dataset_name)
        if data_frame is None:
            # load_data has already reported which file is missing.
            print(f'Skipping {dataset_name}: dataset could not be loaded')
            continue
        # GranularBall reads the sensitive column index from the global args.
        args.sen_index = sen_index
        print(data_frame)
        print(args.sen_index)

        data = _dedupe_by_features(data_frame.values)
        # Scale every feature to [0, 1]; column 0 (the label) is kept as is.
        # NOTE(review): the scaler is fit on all rows before the CV split, so
        # test-fold min/max leak into training. Kept to preserve reported results.
        data = np.hstack((data[:, :1], MinMaxScaler().fit_transform(data[:, 1:])))
        features, target = data[:, 1:], data[:, 0]

        skf = StratifiedKFold(n_splits, shuffle=True, random_state=42)
        file_path = os.path.join(result_dir, str(dataset_name) + '.txt')
        with open(file_path, mode='a') as file:
            file.write(dataset_name + '\n\n')
            for purity in puritys:
                # Per-fold results for each metric, for mean/std reporting.
                results = {metric: [] for metric in metric_names}
                for train_index, test_index in skf.split(features, target):
                    train, test = data[train_index], data[test_index]
                    X_test, Y_test = test[:, 1:], test[:, 0]
                    # fairlearn only uses group membership. Raw sensitive values
                    # give the same groups as re-encoding them by the training
                    # fold's unique values, and do not raise IndexError when a
                    # value is absent from the training fold.
                    S_test = X_test[:, args.sen_index]

                    start = time.time()
                    Ball_list = generateGBList(train, purity)
                    pred_label, _ = GB_KNN1(X_test, Ball_list)
                    end = time.time()

                    results['accuracy'].append(accuracy_score(Y_test, pred_label))
                    results['f1'].append(f1_score(Y_test, pred_label))
                    results['recall'].append(recall_score(Y_test, pred_label))
                    results['dp'].append(demographic_parity_difference(
                        Y_test, pred_label, sensitive_features=S_test))
                    results['eo'].append(equalized_odds_difference(
                        Y_test, pred_label, sensitive_features=S_test))
                    results['ge'].append(generalized_entropy_error(Y_test, pred_label))
                    results['gb_num'].append(len(Ball_list))
                    results['time'].append(end - start)

                file.write('purity_threshold is ' + str(purity) + '\n')
                for metric in metric_names:
                    file.write(_avg_std_lines(metric, results[metric], dataset_name))

                print(dataset_name, 'done!')

            file.write('all done!!!!!')


if __name__ == '__main__':
    # Run the experiment sweep only when executed as a script, not on import.
    main()




