import pandas as pd
import numpy as np
import os
import random
import warnings
import sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from sklearn.neural_network import MLPClassifier
import argparse
import postprocess
import time
# Command-line interface. NOTE: main() overwrites args.sen_index with the
# per-dataset value returned by load_data(), so --sen_index does not affect
# the experiment loop.
parser = argparse.ArgumentParser(description='FairGBFC')
parser.add_argument(
    '--sen_index',
    type=int,
    default=1,
    help='index of sensitive attribute',
)
args = parser.parse_args()
# Silence all warnings for the remainder of the run.
warnings.filterwarnings("ignore")


def load_data(name, dataset_path=None):
    """Load a processed benchmark CSV and print label/sensitive-group statistics.

    The CSV's column 0 is treated as the label and the configured column as
    the sensitive attribute.

    Args:
        name: Dataset key without the ``.csv`` extension, e.g. ``'adult'``.
            Keys ending in ``1`` (``'adult1'``, ``'law_admission1'``) read the
            same CSV as their base dataset but select a different sensitive
            column.
        dataset_path: Directory containing the CSV files. Defaults to
            ``../../Datasets/processed dataset/`` relative to this file.

    Returns:
        ``(data, sen_index, cat_indices, cont_indices)``. ``data`` is the
        loaded ``DataFrame``. ``sen_index`` is the sensitive column's CSV
        index minus one, i.e. its position once the label column is dropped.
        The index lists are returned exactly as configured below.
        If the CSV file does not exist, an error is printed and
        ``(None, None, None, None)`` is returned.

    Raises:
        ValueError: If ``name`` is not a known dataset key.
    """
    if dataset_path is None:
        src_path = os.path.dirname(os.path.abspath(__file__))
        dataset_path = os.path.join(src_path, '../../Datasets/processed dataset/')
    print(dataset_path)

    # key -> (CSV file, sensitive column index in the CSV,
    #         categorical column indices, continuous column indices)
    configs = {
        # Gender
        'credit_approval': ('credit_approval.csv', 1,
                            [0, 3, 4, 5, 6, 8, 9, 11, 12],
                            [1, 2, 7, 10, 13, 14]),
        # sex
        'adult': ('adult.csv', 8, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8]),
        # race
        'adult1': ('adult.csv', 7, [1, 2, 3, 4, 5, 6, 7, 9], [0, 8]),
        # SEX
        'credit_default': ('credit_default.csv', 2,
                           [1, 2, 3, 5, 6, 7, 8, 9, 10],
                           [0, 4] + list(range(11, 23))),
        # gender
        'law_admission': ('law_admission.csv', 6, [4, 5, 6], [0, 1, 2, 3]),
        # race
        'law_admission1': ('law_admission.csv', 7, [4, 5, 6], [0, 1, 2, 3]),
        # Gender
        'german': ('german.csv', 1, [0, 1, 2] + list(range(6, 27)), [3, 4, 5]),
        # sex
        'por': ('por.csv', 2, [0, 1] + list(range(3, 28)), [2, 28, 29, 30]),
    }
    try:
        file_name, sen_index, cat_indices, cont_indices = configs[name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(configs)}"
        ) from None

    data_path = os.path.join(dataset_path, file_name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None
    data = pd.read_csv(data_path)

    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]

    print(f"\nDataset: {file_name}")

    print("\n✅ 敏感属性列的值比例：")
    print(sensi_col.value_counts(normalize=True))

    print("\n✅ 标签的值比例：")
    print(label_col.value_counts(normalize=True))

    pos_label = 1
    is_pos = label_col == pos_label
    # Compare the first two groups in order of appearance. NaN is excluded
    # (as value_counts() above does); a NaN "group" would otherwise select
    # no rows and yield a NaN rate.
    groups = sensi_col.dropna().unique()
    if len(groups) >= 2:
        g1, g2 = groups[0], groups[1]
        g1_mask = sensi_col == g1
        g2_mask = sensi_col == g2

        g1_total = int(g1_mask.sum())
        g2_total = int(g2_mask.sum())
        g1_pos = (is_pos & g1_mask).sum()
        g2_pos = (is_pos & g2_mask).sum()

        # Demographic-parity gap of the ground-truth labels between the groups.
        dp = abs(g1_pos / g1_total - g2_pos / g2_total)
        print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

        pos_total = is_pos.sum()
        neg_total = (~is_pos).sum()

        print("\n🔍 样本详细统计：")
        print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
        print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
        print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
        print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
        print(f"  - 正类样本总数：{pos_total}")
        print(f"  - 负类样本总数：{neg_total}")

    return data, sen_index - 1, cat_indices, cont_indices

def _dedupe_by_features(rows):
    """Drop rows whose feature vector (every column except column 0) repeats.

    The first occurrence of each feature vector is kept and row order is
    preserved. Seen vectors are tracked in a set of tuples, so this is linear
    in the number of rows instead of the quadratic list-membership scan.

    Args:
        rows: 2-D array whose column 0 is the label.

    Returns:
        A new ``np.ndarray`` containing the retained rows.
    """
    seen = set()
    kept = []
    for row in rows.tolist():
        key = tuple(row[1:])
        if key not in seen:
            seen.add(key)
            kept.append(row)
    return np.array(kept)


def _mean_std_report(metric, dataset_name, values):
    """Format the mean and population standard deviation (``np.std``) of ``values``."""
    return (f'Average {metric} of {dataset_name}: {np.mean(values):.6f}\n'
            f'Std {metric} of {dataset_name}: {np.std(values):.6f}\n')


def _run_fold(train, test, sen_index, alpha):
    """Fit an MLP plus an SP post-processor on one CV fold and score it on ``test``.

    Args:
        train: Rows of the scaled data used for fitting; column 0 is the label.
        test: Rows of the scaled data used for evaluation; same layout.
        sen_index: Index of the sensitive attribute among the feature columns.
        alpha: Value forwarded as ``alpha`` to ``postprocess.PostProcessor``.

    Returns:
        Dict mapping 'accuracy', 'f1', 'recall', 'dp', 'eo', 'ge' to their
        values on ``test``.
    """
    X_train, y_train = train[:, 1:], train[:, 0]
    X_test, y_test = test[:, 1:], test[:, 0]

    # Encode the sensitive attribute as group ids 0..m-1 using the sorted
    # values seen in training (np.unique returns them sorted). A test value
    # never seen in training raises IndexError here.
    unique_vals = np.unique(X_train[:, sen_index])
    m = len(unique_vals)
    s_train = np.array([np.where(unique_vals == x)[0][0] for x in X_train[:, sen_index]])
    s_test = np.array([np.where(unique_vals == x)[0][0] for x in X_test[:, sen_index]])

    # Half of the training fold trains the base classifier, the other half
    # fits the post-processor. No random_state is passed, so the split
    # depends on NumPy's global RNG state (seeded once in main()).
    (inputs_pretrain, inputs_postproc, labels_pretrain, _labels_postproc,
     _groups_pretrain, _groups_postproc) = sklearn.model_selection.train_test_split(
        X_train,
        y_train,
        s_train,
        test_size=0.5,
    )

    mlp = MLPClassifier(
        hidden_layer_sizes=(128, 64, 32, 16),
        activation='relu',
        solver='adam',
        alpha=0.0001,
        batch_size='auto',
        learning_rate='constant',
        learning_rate_init=0.001,
        max_iter=500,
        early_stopping=True,
        validation_fraction=0.1,
        n_iter_no_change=10,
        random_state=42
    )
    mlp.fit(inputs_pretrain, labels_pretrain)

    def predict_y(x):
        # Two-column class-probability matrix from the base classifier.
        return mlp.predict_proba(x).reshape(-1, 2)

    def predict_a(x):
        # One-hot group membership derived from the sensitive feature column.
        return np.eye(m)[np.searchsorted(unique_vals, x[:, sen_index])]

    postprocessor = postprocess.PostProcessor(
        2,
        m,
        pred_y_fn=predict_y,
        pred_a_fn=predict_a,
        criterion='sp',
        alpha=alpha,
    )
    postprocessor.fit(inputs_postproc, solver=None)

    predict_label = postprocessor.predict(X_test)

    return {
        'accuracy': accuracy_score(y_test, predict_label),
        'f1': f1_score(y_test, predict_label),
        'recall': recall_score(y_test, predict_label),
        'dp': demographic_parity_difference(y_test, predict_label, sensitive_features=s_test),
        'eo': equalized_odds_difference(y_test, predict_label, sensitive_features=s_test),
        'ge': generalized_entropy_error(y_test, predict_label),
    }


def main():
    """Run the SP post-processing experiment over every dataset and alpha.

    For each dataset the steps are:

    1. Load the CSV and drop rows with duplicate feature vectors.
    2. Min-max scale the features (the label column is left unscaled).
    3. For every ``alpha``, run 5-fold stratified CV via ``_run_fold``.
    4. Append the mean/std of each metric to
       ``<dir>/FairGBFC/comparative_result/Fair/Post/<dataset>.txt``.

    Datasets whose CSV is missing are skipped.
    """
    warnings.filterwarnings("ignore")  # ignore warnings
    filenames = ['credit_approval', 'adult', 'credit_default', 'law_admission', 'german', 'por']
    n_splits = 5
    np.random.seed(42)
    # NOTE: realpath('__file__') takes the *string* '__file__', so <dir> is
    # the current working directory, not this script's directory (unlike
    # load_data, which uses the real __file__). Kept to preserve where
    # results are written.
    src_path = os.path.dirname(os.path.realpath('__file__'))
    out_dir = os.path.join(src_path, 'FairGBFC', 'comparative_result', 'Fair', 'Post')
    os.makedirs(out_dir, exist_ok=True)
    alphas = [0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    # Order in which metric summaries are written to the result file.
    metric_names = ('accuracy', 'f1', 'recall', 'dp', 'eo', 'ge')

    for dataset_name in filenames:
        data_frame, sen_index, _cat_indices, _cont_indices = load_data(dataset_name)
        if data_frame is None:
            # load_data has already reported the missing file.
            continue
        # The per-dataset index replaces any --sen_index given on the CLI.
        args.sen_index = sen_index
        print(data_frame)
        print(args.sen_index)

        data = _dedupe_by_features(data_frame.values)
        # Keep the label (column 0) as-is and min-max scale the features.
        data = np.hstack((data[:, :1], MinMaxScaler().fit_transform(data[:, 1:])))
        features = data[:, 1:]
        labels = data[:, 0]

        # Fixed random_state: every alpha is evaluated on identical folds.
        skf = StratifiedKFold(n_splits, shuffle=True, random_state=42)

        file_path = os.path.join(out_dir, f'{dataset_name}.txt')
        with open(file_path, mode='a', encoding='utf-8') as file:
            file.write(dataset_name + '\n\n')
            for alpha in alphas:
                args.post_alpha = alpha
                results = {metric: [] for metric in metric_names}

                for train_index, test_index in skf.split(features, labels):
                    fold = _run_fold(data[train_index], data[test_index], sen_index, alpha)
                    for metric in metric_names:
                        results[metric].append(fold[metric])

                file.write('alpha is ' + str(alpha) + '\n')
                for metric in metric_names:
                    file.write(_mean_std_report(metric, dataset_name, results[metric]))
                file.write('\n\n')
                # Persist each alpha's results promptly; runs are long.
                file.flush()
                print(dataset_name, 'done!')

            # Trailing newline so the next appended run starts on a new line.
            file.write('all done!!!!!\n')


# Entry point: run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()