import argparse
import copy
import csv
import os
import random
import sys
import time
import warnings
from collections import Counter, deque

import numpy as np
import pandas as pd
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
from matplotlib import pyplot as plt
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MinMaxScaler

from GAFC import fair_kmeans
# Command-line hyper-parameters. ``args`` is a module-level global read by the
# functions below; main() overwrites several of its fields during the grid search.
parser = argparse.ArgumentParser(description='FairGBFC')
parser.add_argument('--sen_index', type=int, default=1, help='index of sensitive attribute')
parser.add_argument('--base_p1', type=int, default=1000, help='GAFC')  # (200)
# Declared so that ``args.base_p2`` exists: splits_ball() passes it to fair_kmeans.
parser.add_argument('--base_p2', type=float, default=0, help='GAFC')
parser.add_argument('--min_num', type=int, default=20, help='RBF')
parser.add_argument('--purity_threshold', type=float, default=0.90, help='purity_threshold')
parser.add_argument('--mixup_k', type=int, default=10, help='mixup_k')
args = parser.parse_args()

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
# NOTE(review): this runs after ``from GAFC import fair_kmeans`` at the top of
# the file, so it only affects imports performed later.
sys.path.append(project_root)
warnings.filterwarnings("ignore")


class GranularBall:
    """A granular ball: a block of labelled samples plus summary statistics.

    ``data`` is a 2-D array whose column 0 is the label and whose remaining
    columns are the features.
    """

    def __init__(self, data) -> None:
        features = data[:, 1:]
        self.data = data
        self.center, self.radius, self.dis = calCenterRadius(features)
        self.num = len(data)
        self.label, self.purity, self.majority_num = calLabelPurity(data)
        # One-hot membership of every sample in the sensitive groups.
        self.sen_ind = indicator_matrix(features, args.sen_index)
        # Imbalance penalty: how many samples exceed m_j copies of the smallest
        # group present in the ball (0 when all present groups are equal-sized).
        per_group = self.sen_ind.sum(axis=0)
        present = per_group[per_group > 0]
        if present.size == 0:
            self.epsilon = 0
        else:
            self.epsilon = self.num - len(present) * present.min()





def indicator_matrix(data, sen_index):
    """Return the one-hot encoding of one column of ``data``.

    data: 2-D array (pure features, or features with label) indexed by column.
    sen_index: column holding the sensitive attribute.

    Returns a float array of shape (n_samples, n_distinct_values). Column j is
    1 where the sample takes the j-th distinct value (``np.unique`` order, i.e.
    sorted) and 0 elsewhere.
    """
    column = data[:, sen_index]
    levels = np.unique(column)
    return (column[:, np.newaxis] == levels[np.newaxis, :]).astype(float)

# Column configuration of every supported dataset, keyed by dataset name (the
# CSV file name without ".csv"). Each value is (sen_index, cat_indices,
# cont_indices); sen_index is the column of the sensitive attribute in the raw
# CSV, whose column 0 holds the label.
# NOTE(review): cat_indices / cont_indices look like 0-based feature-matrix
# indices (label column excluded) -- for every entry, sen_index - 1 is listed as
# categorical. Confirm against the CSV files.
_DATASET_CONFIG = {
    'credit_approval': (1, [0, 3, 4, 5, 6, 8, 9, 11, 12], [1, 2, 7, 10, 13, 14]),  # Gender
    'adult': (8, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8]),  # sex
    'adult1': (7, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8]),  # race
    'bail': (1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14]),  # WHITE
    'bank': (1, [0, 1, 2, 3, 4, 6, 7, 8, 13], [5, 9, 10, 11, 12]),  # age
    'compas': (3, [1, 2, 3], [0, 4, 5, 6]),  # race
    'credit_default': (2, [1, 2, 3, 5, 6, 7, 8, 9, 10],
                       [0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]),  # SEX
    'law_admission': (6, [4, 5, 6], [0, 1, 2, 3]),  # gender
    'law_admission1': (7, [4, 5, 6], [0, 1, 2, 3]),  # race
    'recruitment': (2, [1, 2, 9], [0, 3, 4, 5, 6, 7, 8]),  # Gender
    'thyroid': (2, list(range(1, 16)), [0]),  # Gender
    'german': (1, [0, 1, 2] + list(range(6, 27)), [3, 4, 5]),  # Gender
    'math': (2, [0, 1] + list(range(3, 28)), [2, 28, 29, 30]),  # sex
    'por': (2, [0, 1] + list(range(3, 28)), [2, 28, 29, 30]),  # sex
    'meps': (2, [1] + list(range(5, 138)), [0, 2, 3, 4]),  # RACE
    'diabetic': (1, list(range(0, 5)) + list(range(11, 23)), [6, 7, 8, 9, 10]),  # race
}

# Variant configurations (another sensitive attribute) that read the CSV file
# of a base dataset.
# NOTE(review): 'credit_approval1' has no entry in _DATASET_CONFIG, so its alias
# is currently unreachable.
_DATASET_FILE_ALIASES = {
    'law_admission1': 'law_admission',
    'credit_approval1': 'credit_approval',
    'adult1': 'adult',
}


def _print_label_stats(data, sen_index, pos_label=1):
    """Print the sensitive-attribute and label distributions of ``data``.

    When the sensitive column has at least two distinct values, also print the
    demographic-parity gap of the raw labels between the first two groups (in
    order of appearance) and their per-group sample counts.
    """
    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]

    print("\n✅ 敏感属性列的值比例：")
    print(sensi_col.value_counts(normalize=True))
    print("\n✅ 标签的值比例：")
    print(label_col.value_counts(normalize=True))

    groups = sensi_col.unique()
    if len(groups) < 2:
        return
    g1, g2 = groups[0], groups[1]
    g1_labels = data[sensi_col == g1].iloc[:, 0]
    g2_labels = data[sensi_col == g2].iloc[:, 0]
    g1_total, g2_total = len(g1_labels), len(g2_labels)
    g1_pos = (g1_labels == pos_label).sum()
    g2_pos = (g2_labels == pos_label).sum()

    dp = abs(g1_pos / g1_total - g2_pos / g2_total)
    print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

    print("\n🔍 样本详细统计：")
    print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
    print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
    print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
    print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
    print(f"  - 正类样本总数：{(label_col == pos_label).sum()}")
    print(f"  - 负类样本总数：{(label_col != pos_label).sum()}")


def load_data(name):
    """Load a processed fairness dataset and print summary statistics.

    Parameters
    ----------
    name : str
        Dataset key, one of the keys of ``_DATASET_CONFIG`` (e.g. 'adult', or
        'adult1' for the race variant that reads adult.csv).

    Returns
    -------
    tuple (data, sen_index, cat_indices, cont_indices)
        data : pandas.DataFrame read from
            ``<this file's dir>/Datasets/processed dataset/<file>.csv``;
            column 0 is the label.
        sen_index : the sensitive attribute's raw-CSV column minus 1, i.e. its
            column once the label column is dropped.
        cat_indices, cont_indices : fresh copies of the configured lists.
        All four are None when the CSV file does not exist.

    Raises
    ------
    ValueError
        If ``name`` is not a configured dataset.
    """
    if name not in _DATASET_CONFIG:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(_DATASET_CONFIG)}")
    sen_index, cat_indices, cont_indices = _DATASET_CONFIG[name]

    file_name = _DATASET_FILE_ALIASES.get(name, name) + '.csv'
    src_path = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(src_path, 'Datasets/processed dataset/', file_name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None

    data = pd.read_csv(data_path)
    print(f"\nDataset: {file_name}")
    _print_label_stats(data, sen_index)

    # Copies, so callers cannot mutate the shared configuration table.
    return data, sen_index - 1, list(cat_indices), list(cont_indices)


# Compute the majority label and purity of a granular ball.
def calLabelPurity(data):
    """Return (majority label, purity, majority count) for a labelled block.

    Column 0 of ``data`` holds the labels. Purity is the fraction of rows that
    carry the majority label; on a tie the label encountered first wins. An
    empty block yields (0, 0, 0).
    """
    if len(data) == 0:
        return 0, 0, 0
    tally = Counter(data[:, 0])
    majority, hits = max(tally.items(), key=lambda item: item[1])
    return majority, hits / len(data), hits
    

# Compute the center and radius of a granular ball.
def calCenterRadius(data):
    """Return (center, radius, distances) for a 2-D feature matrix.

    center: column-wise mean of ``data``.
    distances: Euclidean distance of every row to the center, as a 1-D array.
    radius: mean of those distances.
    An empty ``data`` gives a zero center of matching width, radius 0 and an
    empty list of distances.
    """
    if len(data) == 0:
        return np.zeros(data.shape[1]), 0, []
    center = data.mean(axis=0)
    distances = calculateDist(data, center).flatten()
    return center, distances.mean(), distances

# Euclidean distances between samples.
def calculateDist(A, B, flag=0):
    """Euclidean distances between samples.

    flag == 0: pairwise distances between the rows of A and the rows of B
        (a 1-D input is treated as a single row); returns shape (len(A), len(B)).
    flag != 0: one scalar, the Euclidean norm of A - B over all elements.
    """
    if flag != 0:
        return np.sqrt(np.sum((A - B) ** 2))
    rows = A[np.newaxis, :] if A.ndim == 1 else A
    cols = B[np.newaxis, :] if B.ndim == 1 else B
    diff = rows[:, np.newaxis, :] - cols[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=2))
    
def _print_group_table(sensitive_values, groups, total):
    """Print a count / ratio row for each sensitive group in ``groups``."""
    print(f"{'敏感组':<10} {'数量':<10} {'比例':<10}")
    for g in groups:
        count = np.sum(sensitive_values == g)
        print(f"{int(g):<10} {count:<10} {count / total:.2%}")


def _mixup_subset(X, Y, subset_indices, n_needed, k, group, sen_idx, cat_idxs, cont_idxs):
    """Synthesize ``n_needed`` rows by mixing samples of one (group, label) cell.

    Returns a list of 1-D arrays laid out as [label, features...].
    """
    rows = []
    if n_needed <= 0 or len(subset_indices) == 0:
        return rows

    subset_X = X[subset_indices]
    curr_k = min(k, len(subset_X))
    if curr_k < 1:
        return rows

    nbrs = NearestNeighbors(n_neighbors=curr_k, algorithm='auto').fit(subset_X)
    _, all_neighbor_idxs = nbrs.kneighbors(subset_X)

    for _ in range(n_needed):
        seed_local_idx = np.random.randint(len(subset_X))
        neighbor_candidates = all_neighbor_idxs[seed_local_idx]
        # Entry 0 is normally the seed itself, so prefer one of the others.
        if len(neighbor_candidates) > 1:
            target_local_idx = np.random.choice(neighbor_candidates[1:])
        else:
            target_local_idx = neighbor_candidates[0]

        seed_sample = subset_X[seed_local_idx]
        target_sample = subset_X[target_local_idx]
        seed_y = Y[subset_indices[seed_local_idx]]

        lam = np.random.uniform(0, 1)
        # Start from (a copy of) the seed so that columns listed in neither
        # cat_idxs nor cont_idxs keep a real value instead of 0.
        new_x = seed_sample.astype(float)
        if cont_idxs:
            new_x[cont_idxs] = lam * seed_sample[cont_idxs] + (1 - lam) * target_sample[cont_idxs]
        if cat_idxs:
            take_seed = np.random.rand(len(cat_idxs)) < lam
            new_x[cat_idxs] = np.where(take_seed, seed_sample[cat_idxs], target_sample[cat_idxs])
        new_x[sen_idx] = group

        rows.append(np.hstack(([seed_y], new_x)))
    return rows


def neighborhood_mixup(X, Y, sen_index, cat_indices, cont_indices, k=5):
    """Oversample smaller sensitive groups with neighborhood mixup.

    Every sensitive group is topped up to the size of the largest group. The
    synthetic samples of a group are split between labels 1 and 0 so that the
    group's positive rate is preserved. Each synthetic sample mixes a random
    seed sample with another point from the seed's ``k``-nearest-neighbour
    list, both taken from the same (group, label) cell:
    - continuous columns are interpolated with a random weight ``lam``;
    - categorical columns are copied from the seed with probability ``lam``,
      else from the neighbour;
    - the sensitive column is set to the group value.
    The synthetic sample inherits the seed's label.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    Y : ndarray, shape (n_samples,). Assumed to hold 0/1 labels; samples with
        any other label are never used as seeds.
    sen_index : int
        Column of ``X`` holding the sensitive attribute.
    cat_indices, cont_indices : iterables of int
        Columns of ``X`` (0-based, label not included) holding categorical /
        continuous features. Out-of-range entries are ignored. Columns in
        neither list are copied from the seed.
    k : int
        Neighbourhood size, clipped to the cell size.

    Returns
    -------
    ndarray, shape (n_samples + n_synthetic, n_features + 1)
        Column 0 is the label and columns 1: are the features. The original
        rows come first, followed by the synthetic rows.
    """
    X_sen_idx = int(sen_index)
    n_features = X.shape[1]
    # The indices address columns of X directly, so column 0 is a legitimate
    # feature; only indices outside X are discarded.
    X_cat_idxs = [int(i) for i in cat_indices if 0 <= i < n_features]
    X_cont_idxs = [int(i) for i in cont_indices if 0 <= i < n_features]

    sensitive_attr = X[:, X_sen_idx]
    groups = np.unique(sensitive_attr)
    total_samples = len(X)
    original_data = np.column_stack((Y, X))
    if total_samples == 0:
        return original_data

    max_count = max(np.sum(sensitive_attr == g) for g in groups)

    # Distribution before augmentation.
    print("-" * 50)
    print(f"【增强前】 数据总数: {total_samples}")
    _print_group_table(sensitive_attr, groups, total_samples)
    print(f"目标数量 (Max Count): {max_count}")
    print("-" * 50)

    augmented_data = []
    for g in groups:
        g_indices = np.where(sensitive_attr == g)[0]
        current_count = len(g_indices)  # > 0: g comes from np.unique
        if current_count >= max_count:
            continue
        n_augment_total = max_count - current_count

        g_pos_indices = g_indices[Y[g_indices] == 1]
        g_neg_indices = g_indices[Y[g_indices] == 0]
        # Split the synthetic budget to keep the group's positive rate.
        pos_rate = len(g_pos_indices) / current_count
        n_augment_pos = int(np.round(n_augment_total * pos_rate))
        n_augment_neg = n_augment_total - n_augment_pos

        augmented_data += _mixup_subset(X, Y, g_pos_indices, n_augment_pos, k, g,
                                        X_sen_idx, X_cat_idxs, X_cont_idxs)
        augmented_data += _mixup_subset(X, Y, g_neg_indices, n_augment_neg, k, g,
                                        X_sen_idx, X_cat_idxs, X_cont_idxs)

    if augmented_data:
        final_data = np.vstack((original_data, np.array(augmented_data)))
    else:
        final_data = original_data

    # Distribution after augmentation (column 0 is Y, so shift by one).
    final_total = len(final_data)
    final_sen_col = final_data[:, X_sen_idx + 1]
    print(f"【增强后】 数据总数: {final_total} (新增: {final_total - total_samples})")
    _print_group_table(final_sen_col, groups, final_total)
    print("-" * 50)

    return final_data



def splits_ball(gb):
    """Split a granular ball into child balls with fair k-means (k = 2).

    The two seed centers are the ball's own center and one randomly chosen
    sample whose label differs from the ball's majority label.

    Returns a list of GranularBall children, or None when the ball cannot be
    split. That happens when it has no sample of another label to seed from,
    or when the clustering puts every sample in one cluster.
    """
    data = gb.data
    minority_idx = np.flatnonzero(data[:, 0] != gb.label)
    if minority_idx.size == 0:
        # A pure ball has nothing to seed the second cluster with.
        return None
    chosen_idx = np.random.choice(minority_idx)
    centers = [gb.center, data[chosen_idx, 1:]]

    # Read defensively: base_p2 may not be present on ``args``.
    base_p2 = getattr(args, 'base_p2', 0)
    # NOTE(review): ``data`` still carries the label in column 0 while the seed
    # centers are feature-only; presumably fair_kmeans accounts for this --
    # confirm against GAFC.
    cluster_ids, _, _ = fair_kmeans(data, 2, centers, gb.sen_ind, args.base_p1, base_p2)

    unique_ids = np.unique(cluster_ids)
    if len(unique_ids) < 2:
        return None

    children = []
    for cid in unique_ids:
        members = data[cluster_ids == cid]
        if len(members) > 0:
            children.append(GranularBall(members))
    return children
        

def splits(gb_list):
    """Split granular balls until each is pure enough or small, then prune.

    Work-queue version of the recursive split:
    1. A ball whose purity reaches ``args.purity_threshold``, or which holds at
       most ``args.min_num`` samples, is final and leaves the queue.
    2. Any other ball is split with ``splits_ball``. If it cannot be split it
       is kept as final; otherwise its children are classified the same way.

    Returns only the final balls that are both pure enough and hold at least
    ``args.min_num`` samples. The result may be empty.
    """
    def is_final(ball):
        return ball.purity >= args.purity_threshold or len(ball.data) <= args.min_num

    final_gb_list = []
    # deque: FIFO processing with O(1) pops from the left.
    pending = deque(gb_list)

    while pending:
        gb = pending.popleft()
        if is_final(gb):
            final_gb_list.append(gb)
            continue

        children = splits_ball(gb)
        if children is None:
            final_gb_list.append(gb)
            continue

        for child in children:
            if is_final(child):
                final_gb_list.append(child)
            else:
                pending.append(child)

    return [
        gb for gb in final_gb_list
        if gb.purity >= args.purity_threshold and gb.num >= args.min_num
    ]



def FairGBFC_predict(X_test, gb_list):
    """Predict a label for every row of ``X_test`` from granular balls.

    The distance from a sample to a ball is the Euclidean distance to the
    ball's center minus its radius, clipped at 0 (0 means "inside the ball").
    - A sample inside one or more balls takes the label of the covering ball
      with the highest score ``purity * exp(-epsilon)``.
    - A sample inside no ball takes the label of the nearest ball.

    Returns an int array of shape (len(X_test),).

    Raises ValueError if ``gb_list`` is empty.
    """
    if len(gb_list) == 0:
        raise ValueError("FairGBFC_predict needs at least one granular ball")

    gb_centers = np.array([gb.center for gb in gb_list])
    gb_radii = np.array([gb.radius for gb in gb_list])
    gb_labels = np.array([gb.label for gb in gb_list])
    gb_purities = np.array([gb.purity for gb in gb_list], dtype=float)
    gb_epsilons = np.array([gb.epsilon for gb in gb_list], dtype=float)

    raw_dists = calculateDist(X_test, gb_centers)                      # (N, M)
    alg_dists = np.maximum(raw_dists - gb_radii[np.newaxis, :], 0)      # (N, M)
    inside = alg_dists == 0
    is_overlap = np.any(inside, axis=1)                                 # (N,)
    preds = np.zeros(len(X_test), dtype=int)

    if np.any(is_overlap):
        # Rank by log(purity) - epsilon, which orders balls exactly like
        # purity * exp(-epsilon). The direct form underflows to 0 once epsilon
        # exceeds ~745, which would tie every candidate and silently pick the
        # first covering ball.
        log_scores = np.log(np.maximum(gb_purities, np.finfo(float).tiny)) - gb_epsilons
        candidate_scores = np.where(inside, log_scores[np.newaxis, :], -np.inf)
        preds[is_overlap] = gb_labels[np.argmax(candidate_scores[is_overlap], axis=1)]

    if not np.all(is_overlap):
        gap = ~is_overlap
        preds[gap] = gb_labels[np.argmin(alg_dists[gap], axis=1)]

    return preds

def prepare_augmented_folds(dataset_name, data_norm, sen_index, cat_indices, cont_indices,
                            n_splits, random_seed, cache_dir, reuse_cache=False):
    """Build stratified CV folds, augment each training fold, and cache them.

    Column 0 of ``data_norm`` is the label and drives the stratification. For
    fold ``i`` this writes into ``cache_dir`` (created if needed):
    - ``<dataset_name>_fold<i>_train_aug.npy``: the training rows after
      ``neighborhood_mixup``;
    - ``<dataset_name>_fold<i>_test.npy``: the untouched test rows.

    reuse_cache : bool, default False
        When True, folds whose two cache files already exist are skipped. The
        default regenerates every fold. Skipped folds do not advance the global
        NumPy RNG, so later regenerated folds can differ from a full rebuild.
    """
    os.makedirs(cache_dir, exist_ok=True)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
    features = data_norm[:, 1:]
    labels = data_norm[:, 0]

    print(f"🔄 Starting Data Augmentation and Caching for {dataset_name}...")

    for fold_idx, (train_index, test_index) in enumerate(skf.split(features, labels)):
        train_file = os.path.join(cache_dir, f'{dataset_name}_fold{fold_idx}_train_aug.npy')
        test_file = os.path.join(cache_dir, f'{dataset_name}_fold{fold_idx}_test.npy')

        if reuse_cache and os.path.exists(train_file) and os.path.exists(test_file):
            print(f"  Fold {fold_idx}: Cache found. Skipping generation.")
            continue

        print(f"  ⚙️ Fold {fold_idx}: Generating augmented data...")
        train_raw = data_norm[train_index]
        test_raw = data_norm[test_index]

        train_augmented = neighborhood_mixup(
            train_raw[:, 1:],
            train_raw[:, 0],
            sen_index,
            cat_indices,
            cont_indices,
            k=args.mixup_k,
        )

        np.save(train_file, train_augmented)
        np.save(test_file, test_raw)

        print(f"  Fold {fold_idx}: Saved. Train ({len(train_raw)}->{len(train_augmented)}), Test ({len(test_raw)})")

    print("✅ Data Preparation Complete.\n")





def _deduplicate_rows(all_data):
    """Drop rows whose features (columns 1:) repeat an earlier row; keep the first.

    Returns a NumPy array built from the kept rows. A set of feature tuples
    replaces the original O(n^2) list-membership scan.
    """
    seen = set()
    kept = []
    for row in all_data.tolist():
        key = tuple(row[1:])
        if key not in seen:
            seen.add(key)
            kept.append(row)
    return np.array(kept)


def _evaluate_fold(dataset_name, fold_idx, cache_dir):
    """Fit FairGBFC on one cached training fold and score it on its test fold.

    Reads ``<dataset_name>_fold<fold_idx>_train_aug.npy`` and ``..._test.npy``
    from ``cache_dir`` (column 0 = label). Side effect: sets ``args.min_num``
    to the number of sensitive groups in the training fold.

    Returns a dict with keys 'accuracy', 'f1', 'recall', 'dp', 'eo', 'ge',
    'gb_num' and 'time'. 'time' covers building the balls and predicting.
    """
    train = np.load(os.path.join(cache_dir, f'{dataset_name}_fold{fold_idx}_train_aug.npy'))
    test = np.load(os.path.join(cache_dir, f'{dataset_name}_fold{fold_idx}_test.npy'))

    unique_vals = np.sort(np.unique(train[:, 1:][:, args.sen_index]))
    X_test = test[:, 1:]
    Y_test = test[:, 0]
    # Encode each test sensitive value as its position among the training groups.
    S_test = np.array([np.where(unique_vals == x)[0][0] for x in X_test[:, args.sen_index]])
    # The minimum ball size is the number of sensitive-attribute values.
    args.min_num = int(len(unique_vals))

    start = time.time()
    gb_list = splits(gb_list=[GranularBall(train)])
    pred_label = FairGBFC_predict(X_test, gb_list)
    end = time.time()

    return {
        'accuracy': accuracy_score(Y_test, pred_label),
        'f1': f1_score(Y_test, pred_label),
        'recall': recall_score(Y_test, pred_label),
        'dp': demographic_parity_difference(Y_test, pred_label, sensitive_features=S_test),
        'eo': equalized_odds_difference(Y_test, pred_label, sensitive_features=S_test),
        'ge': generalized_entropy_error(Y_test, pred_label),
        'gb_num': len(gb_list),
        'time': end - start,
    }


def main():
    """Run the FairGBFC grid search over every configured dataset.

    For each dataset:
    1. Load it and drop rows with duplicate features.
    2. Min-max scale the features, then build and cache the augmented CV folds.
    3. For every (p1, purity_threshold) pair, evaluate all folds and append the
       mean/std of each metric to ``FairGBFC/result/FairGBFC_<dataset>.txt``.
    Each result file is closed when its dataset finishes, and completed
    datasets end with 'all done!!!!!'.
    """
    n_splits = 5
    np.set_printoptions(suppress=True)
    filenames = ['credit_approval', 'adult', 'credit_default', 'law_admission', 'german', 'por']
    random_seed = 42
    np.random.seed(42)
    # NOTE(review): '__file__' is a string literal, so this resolves to the
    # current working directory rather than this script's directory (load_data
    # uses the real __file__). Kept so result/cache locations do not move.
    src_path = os.path.dirname(os.path.realpath('__file__'))
    result_path = os.path.join(src_path, 'FairGBFC/result/')
    cache_dir = os.path.join(src_path, 'FairGBFC/cache_aug_data/')
    os.makedirs(result_path, exist_ok=True)

    p1_values = np.concatenate((np.arange(0.1, 1, 0.1), np.arange(1, 100, 1)))
    purity_thresholds = [0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]
    metric_names = ('accuracy', 'f1', 'recall', 'dp', 'eo', 'ge', 'gb_num', 'time')

    for dataset_name in filenames:
        file_path = result_path + 'FairGBFC_' + str(dataset_name) + '.txt'
        with open(file_path, mode='a') as file:
            file.write(dataset_name + '\n\n')
            data_frame, sen_index, cat_indices, cont_indices = load_data(dataset_name)
            if data_frame is None:
                continue
            args.sen_index = sen_index

            data = _deduplicate_rows(data_frame.values)
            n_rows = data.shape[0]
            min_max = MinMaxScaler()
            data_norm = np.hstack((data[:, 0].reshape(n_rows, 1),
                                   min_max.fit_transform(data[:, 1:])))

            prepare_augmented_folds(
                dataset_name, data_norm, sen_index, cat_indices, cont_indices,
                n_splits, random_seed, cache_dir
            )

            for p1 in p1_values:
                for p2 in [0]:
                    args.base_p1 = p1
                    args.base_p2 = p2
                    for purity_threshold in purity_thresholds:
                        args.purity_threshold = purity_threshold

                        # Per-fold results, read from the cached fold files.
                        results = {metric: [] for metric in metric_names}
                        for fold_idx in range(n_splits):
                            fold_metrics = _evaluate_fold(dataset_name, fold_idx, cache_dir)
                            for metric in metric_names:
                                results[metric].append(fold_metrics[metric])

                        file.write('p1 is ' + str(args.base_p1) + '\n')
                        file.write('purity_threshold is ' + str(args.purity_threshold) + '\n')
                        for metric in metric_names:
                            values = results[metric]
                            file.write(f'Average {metric} of {dataset_name}: {np.mean(values):.6f}\n'
                                       f'Std {metric} of {dataset_name}: {np.std(values):.6f}\n')

                        print(dataset_name, 'done!')

            file.write('all done!!!!!')




# Run the experiment grid only when this file is executed as a script.
if __name__ == '__main__':
    main()



        



            


