import ast
import math
import time
import warnings
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
import os
import pandas as pd
import random
import scipy.io
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
import matplotlib.pyplot as plt
from sklearn.cluster import k_means
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

import numpy as np
import argparse
# Command-line interface.
# NOTE: parse_args() is called at module level, so the command line is parsed
# as soon as this module is executed or imported.
parser = argparse.ArgumentParser(description='FairGBFC')
parser.add_argument('--sen_index', type=int, default=1, help='index of sensitive attribute')
# main() overwrites args.sen_index with the index returned by load_data().
args = parser.parse_args()



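# 1. Load the input data
# 2. Print / plot the raw data
# 3. Check the purity of each granular-ball
# 4. If the purity requirement is not met, split the granular-ball with k-means
# 5. Plot the data points of each granular-ball
# 6. Compute each granular-ball's mean to obtain its center and radius, then plot the granular-ball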

# Calculate label and purity of granular-balls
def get_label_and_purity(gb):
    """Return the majority label of a granular-ball and its purity.

    Args:
        gb: 2-D NumPy array; column 0 holds the class label of every sample.

    Returns:
        tuple: ``(label, purity)``. ``label`` is the most frequent value in
        column 0 and ``purity`` is the fraction of rows that carry it. When
        several labels share the highest count, the largest label is returned.

    Raises:
        ValueError: if ``gb`` contains no rows.
    """
    labels, counts = np.unique(gb[:, 0], return_counts=True)
    if labels.size == 0:
        raise ValueError("cannot compute label and purity of an empty granular-ball")

    # np.unique returns the labels in ascending order. Searching the reversed
    # counts selects the LAST maximum, i.e. the largest label among ties.
    majority = counts.size - 1 - int(np.argmax(counts[::-1]))
    return labels[majority], counts[majority] / gb.shape[0]


# Calculate granular-balls center and radius
def calculate_center_and_radius(gb):
    """Return the center and mean radius of a granular-ball.

    Args:
        gb: 2-D NumPy array; column 0 is the label, the remaining columns are
            the feature values.

    Returns:
        tuple: ``(center, radius)`` where ``center`` is the per-feature mean
        of the samples and ``radius`` is the mean Euclidean distance from the
        samples to that center.

    Raises:
        ValueError: if ``gb`` contains no rows.
    """
    if len(gb) == 0:
        raise ValueError("cannot compute center and radius of an empty granular-ball")
    features = gb[:, 1:]
    center = features.mean(axis=0)
    # Distance of every sample to the center, computed once.
    distances = ((features - center) ** 2).sum(axis=1) ** 0.5
    return center, np.mean(distances)


# Calculate distance
def calculate_distances(data, p):
    """Return the Euclidean distance between ``data`` and ``p``.

    The squared differences are summed along axis 0, so for two 1-D vectors
    the result is a single scalar.
    """
    offset = data - p
    squared_total = (offset ** 2).sum(axis=0)
    return squared_total ** 0.5


# draw granular-balls
def gb_plot(gb_dict, plt_type=0):
    """Plot the granular-balls stored in ``gb_dict``.

    Args:
        gb_dict: mapping whose values are sequences with the ball's sample
            array at index 0. Only the first three columns of that array
            (label plus two features) are used.
        plt_type: 0 draws samples, circles and centers; 1 draws circles and
            centers; any other value draws centers only.

    Labels must be keys of the color table below (-1, 0, 1, 3..9); any other
    label raises ``KeyError``.
    """
    palette = {-1: 'r', 1: 'k', 0: 'b', 3: 'y', 4: 'g', 5: 'c', 6: 'm', 7: 'peru', 8: 'pink', 9: 'gold'}
    draw_points = plt_type == 0
    draw_circles = plt_type == 0 or plt_type == 1
    center_marker = 'x' if plt_type == 0 else '.'

    plt.figure(figsize=(5, 4))  # width and height of the image
    plt.axis([-1.2, 1.2, -1, 1])

    for entry in gb_dict.values():
        ball = entry[0][:, 0:3]
        majority, _ = get_label_and_purity(ball)
        present_labels = np.unique(ball[:, 0], axis=0)
        center, radius = calculate_center_and_radius(ball)

        if draw_points:
            # One scatter call per label so that each class gets its own color.
            for value in present_labels.tolist():
                members = ball[ball[:, 0] == value]
                plt.plot(members[:, 1], members[:, 2], '.', color=palette[value], markersize=5)

        if draw_circles:
            angles = np.arange(0, 2 * np.pi, 0.01)
            circle_x = center[0] + radius * np.cos(angles)
            circle_y = center[1] + radius * np.sin(angles)
            plt.plot(circle_x, circle_y, palette[majority], linewidth=0.8)

        # The center is always drawn, colored by the ball's majority label.
        plt.plot(center[0], center[1], center_marker, color=palette[majority])

    plt.show()


# draw granular-balls
def plot_gb(granular_ball_list, plt_type=0):
    """Plot a list of granular-balls whose labels are -1, 0 or 1.

    Args:
        granular_ball_list: iterable of 2-D arrays; column 0 is the label and
            columns 1 and 2 are the plotted coordinates.
        plt_type: 0 draws samples, circles and centers; 1 draws circles and
            centers; any other value draws centers only.

    A majority label outside {-1, 0, 1} raises ``KeyError`` because it has no
    entry in the color table.
    """
    color = {-1: 'r', 1: 'k', 0: 'b'}
    plt.figure(figsize=(5, 4))
    plt.axis([-1.2, 1.2, -1, 1])

    for granular_ball in granular_ball_list:
        label, _ = get_label_and_purity(granular_ball)
        center, radius = calculate_center_and_radius(granular_ball)

        if plt_type == 0:
            # Same drawing order as before: label 0, then 1, then -1.
            for value in (0, 1, -1):
                members = granular_ball[granular_ball[:, 0] == value]
                plt.plot(members[:, 1], members[:, 2], '.', color=color[value], markersize=5)

        if plt_type == 0 or plt_type == 1:
            theta = np.arange(0, 2 * np.pi, 0.01)
            x = center[0] + radius * np.cos(theta)
            y = center[1] + radius * np.sin(theta)
            plt.plot(x, y, color[label], linewidth=0.8)

        plt.plot(center[0], center[1], 'x' if plt_type == 0 else '.', color=color[label])

    plt.show()


def splits(purity, gb_dict):
    """Repeatedly split impure granular-balls until the ball count is stable.

    Args:
        purity: purity threshold; a ball whose purity is below it is split.
        gb_dict: mapping ``center key -> [samples, distances]`` describing the
            initial balls. The mapping itself is not modified.

    Returns:
        dict: a new mapping of the same form holding the final balls.

    A ball is split (via ``splits_ball``) only when its purity is below the
    threshold and it does not consist of exactly one sample. Iteration stops
    after the first full pass that leaves the number of balls unchanged.
    """
    current = gb_dict.copy()
    while True:
        refined = {}
        for key, value in current.items():
            samples = value[0]  # all points that belong to this ball
            ball_purity = get_label_and_purity(samples)[1]
            if ball_purity < purity and len(samples) != 1:
                # splits_ball consumes the dict it is given, so hand it a fresh one.
                refined.update(splits_ball({key: value}))
            else:
                refined[key] = value

        # No change in the number of balls means nothing more can be split.
        if len(refined) == len(current):
            return refined
        current = refined

# split granular-balls
def splits_ball(gb_dict):
    """Split one granular-ball into several smaller balls.

    Args:
        gb_dict: dict from which ONE item is removed with ``popitem()`` (the
            argument is therefore mutated). The key is a string of floats
            joined by ``'_'`` (first value: label of the center, remaining
            values: its features). The value is ``[gb, distances]`` where
            ``gb`` is the sample array (label in column 0) and ``distances``
            holds the distance of every sample to that center.
            ``distances`` must be a list in the single-label case because
            ``distances.index(0)`` is called on it.

    Returns:
        dict: ``key -> [samples, distances]`` for the old center and for every
        newly chosen center, with each sample assigned to its nearest center.
        Centers that produce the same key string overwrite each other.

    New centers are picked with ``random.randint``, so the result depends on
    the state of the ``random`` module.
    """
    # {center: [gb, distances]}
    center = []  # old center
    distances_other_class = []  # distance to heterogeneous data points
    balls = []  # The result after clustering
    center_other_class = []
    ball_list = {}  # Returned dictionary result, key: center point, value (ball , distance from data to center)
    distances_other_temp = []

    centers_dict = []  # centers
    gbs_dict = []  # data
    distances_dict = []  # distances

    # Fetch dictionary data, including keys and values
    gb_dict_temp = gb_dict.popitem()
    for center_split in gb_dict_temp[0].split('_'):
        center.append(float(center_split))
    center = np.array(center)
    gb = gb_dict_temp[1][0]  # samples of the granular-ball
    distances = gb_dict_temp[1][1]  # distances of the samples to the old center
    # print('center:', center)
    # print('gb:', gb)
    # print('distances:', distances)
    centers_dict.append(center)  # the old center is always center number 0


    # Take a new center
    # Collect all distinct labels in the ball
    len_label = np.unique(gb[:, 0], axis=0)
    # With a single class one new center (different from the old one) is still
    # picked; with several classes one new center is picked per other class.
    if len(len_label) > 1:
        gb_class = len(len_label)
    else:
        gb_class = 2
    # Take multiple centers for multiple types of data
    len_label = len_label.tolist()
    for i in range(0, gb_class - 1):
        # print(len_label)
        if len(len_label) < 2:
            # Single class: drop the old center (the sample at distance 0) and
            # pick a random new center among the remaining samples.
            # NOTE: distances.index(0) raises ValueError if no distance is exactly 0.
            gb_temp = np.delete(gb, distances.index(0), axis=0)  # Remove the old center
            ran = random.randint(0, len(gb_temp) - 1)
            center_other_temp = gb_temp[ran]  # Take a new center
            center_other_class.append(center_other_temp)
        else:
            # Several classes: the old center is kept and its label is removed
            # from the candidate labels; a random sample of another label then
            # becomes a new center.
            if center[0] in len_label:
                len_label.remove(center[0])
            gb_temp = gb[gb[:, 0] == len_label[i], :]  # Extract heterogeneous data

            # random center of heterogeneity
            ran = random.randint(0, len(gb_temp) - 1)
            center_other_temp = gb_temp[ran]

            # center_other_temp = select_center(gb_temp)
            # print(center_other_temp)
            center_other_class.append(center_other_temp)
            # print(distances.index(max(distances)))
    # print('center_other_class:', center_other_class)
    # join the centers
    centers_dict.extend(center_other_class)
    # print('centers_dict:', centers_dict)

    # Store all data distance to old center
    distances_other_class.append(distances)
    # Calculate the distance to each new center
    # (distance of every sample in the ball to every new center)
    for center_other in center_other_class:
        balls = []  # The result after clustering
        distances_other = []
        for feature in gb:
            distances_other.append(calculate_distances(feature[1:], center_other[1:]))
        # new centers
        # distances_dict.append(distances_other)
        distances_other_temp.append(distances_other)  # Temporary storage distance to each new center
        # Store all data distance to new center
        distances_other_class.append(distances_other)
    # print('distances_other_class:', len(distances_other_class))

    # The distance from a certain data to the original center and the new center, take the smallest for classification
    # Each sample is assigned to the center with the smallest distance; on a
    # tie list.index() returns the first match, so the earlier center wins.
    for i in range(len(distances)):
        distances_temp = []
        distances_temp.append(distances[i])
        for distances_other in distances_other_temp:
            distances_temp.append(distances_other[i])
        # print('distances_temp:', distances_temp)
        classification = distances_temp.index(min(distances_temp))  # 0: old center; 1, 2, ...: new centers
        balls.append(classification)
    # Clustering situation
    balls_array = np.array(balls)
    # print("Clustering situation:", balls_array)

    # Assign data based on clustering
    for i in range(0, len(centers_dict)):
        gbs_dict.append(gb[balls_array == i, :])
    # print('gbs_dict:', gbs_dict)

    # assign new distance
    i = 0
    for j in range(len(centers_dict)):
        distances_dict.append([])
    # print('distances_dict:', distances_dict)
    # `i` walks over the samples while `label` is the index of the center the
    # sample was assigned to.
    for label in balls:
        distances_dict[label].append(distances_other_class[label][i])
        i += 1
    # print('distances_dict:', distances_dict)

    # packed into a dictionary
    for i in range(len(centers_dict)):
        gb_dict_key = str(float(centers_dict[i][0]))
        for j in range(1, len(centers_dict[i])):
            gb_dict_key += '_' + str(float(centers_dict[i][j]))
        gb_dict_value = [gbs_dict[i], distances_dict[i]]  # Pellets + distance to centers
        ball_list[gb_dict_key] = gb_dict_value

    # print('ball_list:', ball_list)
    return ball_list

def nearest_knn(X_test, ball_list):
    """Predict a label for every test sample from its nearest granular-ball.

    Args:
        X_test: iterable of feature vectors (no label column).
        ball_list: iterable of granular-balls (2-D arrays, label in column 0).

    Returns:
        list: for each test sample, the majority label of the ball that
        minimises ``distance to the ball center - ball radius``.
    """
    centers, radii, labels = [], [], []
    for ball in ball_list:
        ball_center, ball_radius = calculate_center_and_radius(ball)
        ball_label = get_label_and_purity(ball)[0]
        centers.append(ball_center)
        radii.append(ball_radius)
        labels.append(ball_label)

    centers = np.array(centers)
    radii = np.array(radii)

    def closest_label(sample):
        # Distance to the surface of every ball (center distance minus radius).
        surface_gap = np.linalg.norm(centers - sample, axis=1) - radii
        return labels[np.argmin(surface_gap)]

    return [closest_label(sample) for sample in X_test]

# Per-dataset metadata: (sensitive-attribute column in the CSV, counted with
# the label column at index 0; categorical feature indices; continuous
# feature indices).
_DATASET_CONFIGS = {
    'credit_approval': (1, (0, 3, 4, 5, 6, 8, 9, 11, 12), (1, 2, 7, 10, 13, 14)),  # Gender
    'adult': (8, (1, 2, 3, 4, 5, 6, 7, 9), (0, 8)),  # sex
    'adult1': (7, (1, 2, 3, 4, 5, 6, 7, 9), (0, 8)),  # race
    'credit_default': (2, (1, 2, 3, 5, 6, 7, 8, 9, 10),
                       (0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)),  # SEX
    'law_admission': (6, (4, 5, 6), (0, 1, 2, 3)),  # gender
    'law_admission1': (7, (4, 5, 6), (0, 1, 2, 3)),  # race
    'german': (1, (0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                   20, 21, 22, 23, 24, 25, 26), (3, 4, 5)),  # Gender
    'por': (2, (0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                19, 20, 21, 22, 23, 24, 25, 26, 27), (2, 28, 29, 30)),  # sex
}

# Dataset variants that are read from another dataset's CSV file.
_DATASET_FILE_ALIASES = {
    'law_admission1': 'law_admission',
    'credit_approval1': 'credit_approval',
    'adult1': 'adult',
}


def load_data(name):
    """Load a processed dataset and print summary statistics about it.

    Args:
        name: dataset name without the ``.csv`` extension.

    Returns:
        tuple: ``(data, sen_index, cat_indices, cont_indices)`` where ``data``
        is the DataFrame read from the CSV (label in column 0), ``sen_index``
        is the sensitive-attribute index with the label column excluded, and
        the two lists are the categorical / continuous feature indices.
        ``(None, None, None, None)`` is returned when the CSV file does not
        exist.

    Raises:
        ValueError: if the CSV file exists but no metadata is configured for
            ``name`` (previously this surfaced as an ``UnboundLocalError``).
    """
    src_path = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(src_path, '..', 'Datasets/processed dataset/')
    print(dataset_path)

    file_name = _DATASET_FILE_ALIASES.get(name, name) + '.csv'
    data_path = os.path.join(dataset_path, file_name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None

    config = _DATASET_CONFIGS.get(name)
    if config is None:
        raise ValueError(
            f"No sensitive-attribute configuration for dataset '{name}'; "
            f"configured datasets: {sorted(_DATASET_CONFIGS)}"
        )
    sen_index, cat_indices, cont_indices = config

    data = pd.read_csv(data_path)
    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]

    # Summary statistics
    print(f"\nDataset: {file_name}")

    # Distribution of the sensitive attribute
    sen_col_ratio = sensi_col.value_counts(normalize=True)
    print(f"\n✅ 敏感属性列的值比例：")
    print(sen_col_ratio)

    # Distribution of the label
    col0_ratio = label_col.value_counts(normalize=True)
    print("\n✅ 标签的值比例：")
    print(col0_ratio)

    # Demographic-parity gap between the first two sensitive groups
    pos_label = 1
    groups = sensi_col.unique()
    if len(groups) >= 2:
        g1, g2 = groups[0], groups[1]
        g1_data = data[sensi_col == g1]
        g2_data = data[sensi_col == g2]

        g1_total = len(g1_data)
        g1_pos = (g1_data.iloc[:, 0] == pos_label).sum()
        g2_total = len(g2_data)
        g2_pos = (g2_data.iloc[:, 0] == pos_label).sum()
        pos_total = (label_col == pos_label).sum()
        neg_total = (label_col != pos_label).sum()

        dp = abs(g1_pos / g1_total - g2_pos / g2_total)
        print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

        # More detailed counts
        print(f"\n🔍 样本详细统计：")
        print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
        print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
        print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
        print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
        print(f"  - 正类样本总数：{pos_total}")
        print(f"  - 负类样本总数：{neg_total}")

    # The returned sensitive index is relative to the feature matrix (label
    # column removed); fresh lists keep the config table safe from mutation.
    return data, sen_index - 1, list(cat_indices), list(cont_indices)

def _deduplicate_rows(rows):
    """Return the rows whose feature part (everything after column 0) is new.

    The first row of every distinct feature vector is kept, in input order.
    """
    seen = set()
    unique_rows = []
    for row in rows:
        features = tuple(row[1:])
        if features not in seen:
            seen.add(features)
            unique_rows.append(row)
    return unique_rows


def _mean_std_report(metric, values, dataset_name):
    """Format the mean and standard deviation of one metric for the result file."""
    return (f'Average {metric} of {dataset_name}: {np.mean(values):.6f}\n'
            f'Std {metric} of {dataset_name}: {np.std(values):.6f}\n')


def main():
    """Run the granular-ball experiment and append the results to text files.

    For every dataset and purity threshold a stratified 5-fold
    cross-validation is run: granular-balls are generated with ``splits``,
    refined with k-means, and the test fold is classified with
    ``nearest_knn``. Mean and standard deviation of accuracy, F1, recall,
    demographic parity, equalized odds, generalized entropy error, number of
    balls and run time are appended to
    ``<result_path>/GBACC/<dataset>.txt``.

    Datasets whose CSV file is missing (``load_data`` returns ``None``) are
    skipped. Both ``numpy.random`` and ``random`` are seeded, because the
    initial and split centers are drawn with ``random.randint``.
    """
    warnings.filterwarnings("ignore")  # ignore warning
    # filenames = ['credit_approval', 'adult', 'bail', 'bank', 'credit_default', 'law_admission','recruitment', 'compas', 'thyroid','german', 'por', 'math']
    filenames = ['law_admission1', 'bank', 'adult1']
    n_splits = 5
    random_seed = 42
    np.random.seed(random_seed)
    random.seed(random_seed)
    puritys = [0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.00]
    # NOTE(review): '__file__' is a string literal here, not the module
    # variable, so this resolves against the current working directory.
    # Left unchanged because switching to the script directory would move
    # the place where results are written.
    src_path = os.path.dirname(os.path.realpath('__file__'))
    result_path = os.path.join(src_path, 'FairGBFC/comparative_result/')

    for dataset_name in filenames:
        data_frame, sen_index, _, _ = load_data(dataset_name)
        if data_frame is None:
            # load_data() has already reported the missing file.
            continue
        args.sen_index = sen_index
        print(data_frame)
        print(args.sen_index)

        # Drop samples whose feature vector duplicates an earlier sample.
        data = np.array(_deduplicate_rows(data_frame.values.tolist()))
        numberSample = data.shape[0]

        # Scale the features to [0, 1]; the label column is left untouched.
        data = np.hstack((data[:, 0].reshape(numberSample, 1),
                          MinMaxScaler().fit_transform(data[:, 1:])))
        train_data = data[:, 1:]
        train_target = data[:, 0]

        skf = StratifiedKFold(n_splits, shuffle=True, random_state=random_seed)

        file_path = result_path + 'GBACC/' + str(dataset_name) + '.txt'
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, mode='a') as file:
            file.write(dataset_name + '\n\n')

            for purity in puritys:
                # Per-fold results, kept so that mean and std can be reported.
                acc_list, f1_list, recall_list = [], [], []
                dp_list, ge_list, eo_list = [], [], []
                gbnum_list, time_list = [], []

                for train_index, test_index in skf.split(train_data, train_target):
                    train, test = data[train_index], data[test_index]
                    X_test = test[:, 1:]
                    Y_test = test[:, 0]

                    # Encode the test sensitive attribute as the rank of its
                    # value among the (sorted) values seen in training.
                    unique_vals = np.unique(train[:, 1:][:, args.sen_index])
                    S_test = np.array([np.where(unique_vals == x)[0][0]
                                       for x in X_test[:, args.sen_index]])

                    # record start time
                    start = time.time()

                    # initialize random center and the distances to it
                    center_init = train[random.randint(0, len(train) - 1), :]
                    distance_init = [calculate_distances(feature[1:], center_init[1:])
                                     for feature in train]
                    gb_dict_key = '_'.join(str(value) for value in center_init.tolist())

                    # perform splitting
                    gb_dict = splits(purity=purity,
                                     gb_dict={gb_dict_key: [train, distance_init]})

                    # extract ball centers (the first key component is the label)
                    splits_k = len(gb_dict)
                    k_centers = [[float(k) for k in key.split('_')][1:] for key in gb_dict]

                    # perform k-means based refinement
                    label_cluster = k_means(X=train[:, 1:], n_clusters=splits_k, n_init=1,
                                            init=np.array(k_centers), random_state=5)[1]
                    ball_list = []
                    for single_label in range(splits_k):
                        members = train[label_cluster == single_label, :]
                        # An empty cluster has no center or label; skip it.
                        if len(members):
                            ball_list.append(members)

                    # predict using nearest GB-KNN
                    pred_label = nearest_knn(X_test, ball_list)

                    # record end time
                    end = time.time()

                    # compute and store the metrics of this fold
                    acc_list.append(accuracy_score(Y_test, pred_label))
                    f1_list.append(f1_score(Y_test, pred_label))
                    recall_list.append(recall_score(Y_test, pred_label))
                    dp_list.append(demographic_parity_difference(
                        Y_test, pred_label, sensitive_features=S_test))
                    eo_list.append(equalized_odds_difference(
                        Y_test, pred_label, sensitive_features=S_test))
                    ge_list.append(generalized_entropy_error(Y_test, pred_label))
                    gbnum_list.append(len(gb_dict))
                    time_list.append(end - start)

                # write mean and std of every metric
                file.write(('purity_threshold is ' + str(purity) + '\n'))
                file.write(_mean_std_report('accuracy', acc_list, dataset_name))
                file.write(_mean_std_report('f1', f1_list, dataset_name))
                file.write(_mean_std_report('recall', recall_list, dataset_name))
                file.write(_mean_std_report('dp', dp_list, dataset_name))
                file.write(_mean_std_report('eo', eo_list, dataset_name))
                file.write(_mean_std_report('ge', ge_list, dataset_name))
                file.write(_mean_std_report('gb_num', gbnum_list, dataset_name))
                file.write(_mean_std_report('time', time_list, dataset_name))
                file.write('\n')

                print(dataset_name, 'done!')

            file.write('all done!!!!!')



# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()
