import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize, StandardScaler
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import os
import random
from mlp import MLP
import warnings
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, roc_auc_score, recall_score, f1_score, precision_score
import argparse
import time
# Command-line options, parsed once when the module is imported.
# main() later overwrites args.sen_index per dataset and args.lam1 per
# lambda in its sweep; args.lam2 is forwarded as fair_training's `lam`.
parser = argparse.ArgumentParser(description='FairGBFC')
_CLI_OPTIONS = (
    # (flag, type, default, help)
    ('--sen_index', int, 1, 'index of sensitive attribute'),
    ('--max_iter', int, 500, 'iter num'),
    ('--learning_rate_init', float, 0.001, 'learning rate'),
    ('--lam1', float, 10, 'lambda1'),
    ('--lam2', float, 0.1, 'lambda2'),
)
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
del _flag, _type, _default, _help
args = parser.parse_args()
warnings.filterwarnings("ignore")
# Device string used for the MLP in main().
if torch.cuda.is_available():
    args.device = 'cuda'
else:
    args.device = 'cpu'

class FERMI(nn.Module):
    """Classifier wrapper holding the state of the FERMI fairness regularizer.

    The regularizer is the variational form of an f-divergence between the
    joint distribution of (predicted label, sensitive group) and the product
    of their marginals. Every term is concave in the dual variable ``W``, so
    training should *maximize* the regularizer over ``W`` and *minimize* it
    over the model parameters. Only binary predictions and sensitive groups
    0 and 1 enter the regularizer.
    """

    def __init__(self, device, model, X_train, Y_train, S_train,
                 batch_size=64, epochs=2000, lam=10):
        """
        Args:
            device: device on which ``W`` is allocated.
            model: module whose forward returns ``(_, logits)`` with two
                logits per sample.
            X_train, Y_train: training features / labels (numpy arrays).
            S_train: sensitive-attribute matrix (numpy, shape (n, k)),
                expected one-hot; k must be at least 2.
            batch_size, epochs: stored for callers; not used by this class.
            lam: weight multiplying the fairness regularizer.

        Raises:
            ValueError: if ``S_train`` is not 2-D with at least two columns.
        """
        super(FERMI, self).__init__()
        if np.ndim(S_train) != 2 or S_train.shape[1] < 2:
            raise ValueError(
                'S_train must be a 2-D one-hot matrix with at least two '
                f'sensitive-group columns, got shape {np.shape(S_train)}')

        self.device = device
        self.model = model

        self.X_train = X_train
        self.Y_train = Y_train
        self.S_train = S_train

        self.batch_size = batch_size
        self.epochs = epochs

        self.n = X_train.shape[0]
        self.d = X_train.shape[1]
        self.k = S_train.shape[1]
        self.m = 2  # binary classification with softmax over 2 classes

        # Dual variable of the variational bound; indexed as W[y][s] in
        # fairness_regularizer.
        self.W = nn.Parameter(torch.zeros(self.k, self.m, device=self.device))

        # Sensitive-group marginals P(s) over the full training set.
        sums = self.S_train.sum(axis=0) / self.n
        self.p_s0 = sums[0]
        self.p_s1 = sums[1]

        # P_s and P_s^{-1/2} are kept as attributes for callers; the methods
        # of this class do not read them. An empty group yields inf here.
        final_entries = [1.0 / np.sqrt(item) for item in sums]
        self.P_s = np.diag(sums)
        self.P_s_sqrt_inv = torch.from_numpy(np.diag(final_entries)).double()

        self.lam = lam

    def forward(self, X):
        """Return class probabilities (softmax over the model's logits)."""
        _, logits = self.model(X.float())
        return torch.softmax(logits, dim=1)  # (N, 2)

    def fairness_regularizer(self, X, S, f_divergence='Chi2'):
        """Return ``lam`` times the variational f-divergence bound on a batch.

        Args:
            X: feature batch (tensor), passed through ``forward``.
            S: sensitive-attribute batch (tensor, N x k), expected one-hot;
                only columns 0 and 1 are used.
            f_divergence: 'Chi2' or 'KL'.

        Raises:
            ValueError: for any other ``f_divergence``.
        """
        probs = self.forward(X)  # (N, 2)
        Y_hat = probs[:, 1]      # p(y=1|x)
        Y_hat0 = probs[:, 0]     # p(y=0|x)

        # Batch estimates of P(y_hat) and of the joint P(y_hat, s); the
        # sensitive marginals self.p_s0 / self.p_s1 come from the training set.
        p_y1 = torch.mean(Y_hat)
        p_y0 = 1 - p_y1

        p_y1s0 = torch.mean(Y_hat * S[:, 0])
        p_y1s1 = torch.mean(Y_hat * S[:, 1])
        p_y0s0 = torch.mean(Y_hat0 * S[:, 0])
        p_y0s1 = torch.mean(Y_hat0 * S[:, 1])

        W = self.W.double()
        reg = 0
        if f_divergence == 'Chi2':
            # Each term 2*p_ys*w - p_s*p_y*w^2 + p_s*p_y - 2*p_ys is maximised
            # at w = p_ys / (p_s*p_y), where it equals
            # (p_ys - p_s*p_y)^2 / (p_s*p_y), i.e. one chi^2 summand.
            reg += 2 * p_y1s1 * W[1][1] - self.p_s1 * p_y1 * W[1][1] ** 2 + self.p_s1 * p_y1 - 2 * p_y1s1
            reg += 2 * p_y0s0 * W[0][0] - self.p_s0 * p_y0 * W[0][0] ** 2 + self.p_s0 * p_y0 - 2 * p_y0s0
            reg += 2 * p_y1s0 * W[1][0] - self.p_s0 * p_y1 * W[1][0] ** 2 + self.p_s0 * p_y1 - 2 * p_y1s0
            reg += 2 * p_y0s1 * W[0][1] - self.p_s1 * p_y0 * W[0][1] ** 2 + self.p_s1 * p_y0 - 2 * p_y0s1
        elif f_divergence == 'KL':
            # exp(w - 1) is the convex conjugate of t*log(t).
            reg += p_y1s1 * W[1][1] - self.p_s1 * p_y1 * torch.exp(W[1][1] - 1)
            reg += p_y0s0 * W[0][0] - self.p_s0 * p_y0 * torch.exp(W[0][0] - 1)
            reg += p_y1s0 * W[1][0] - self.p_s0 * p_y1 * torch.exp(W[1][0] - 1)
            reg += p_y0s1 * W[0][1] - self.p_s1 * p_y0 * torch.exp(W[0][1] - 1)
        else:
            raise ValueError(
                f"Unsupported f_divergence {f_divergence!r}; expected 'Chi2' or 'KL'")

        return self.lam * reg

def one_hot_encode_sensitive(s, num_classes=None):
    """One-hot encode non-negative integer group codes.

    Float codes are rounded to the nearest integer rather than truncated: a
    min-max scaled binary column can contain values such as
    0.9999999999999999, which truncation would silently merge into group 0.

    Args:
        s: 1-D array-like of non-negative group codes.
        num_classes: minimum number of output columns. Defaults to
            ``max(s) + 1``. Passing 2 guarantees a two-column matrix even
            when only one group is present.

    Returns:
        float array of shape (len(s), max(num_classes, max(s) + 1)).

    Raises:
        ValueError: if any code is negative.
    """
    codes = np.rint(np.asarray(s, dtype=float)).astype(int)
    if codes.size and codes.min() < 0:
        raise ValueError('sensitive-attribute codes must be non-negative')
    width = int(codes.max()) + 1 if codes.size else 0
    if num_classes is not None:
        width = max(width, int(num_classes))
    one_hot = np.zeros((len(codes), width))
    one_hot[np.arange(len(codes)), codes] = 1
    return one_hot

def fair_training(fermi, batch_size, epochs, device, initial_epochs=300, initial_learning_rate=1,
                  lam=0.1, learning_rate_min=0.01, learning_rate_max=0.01, f_divergence='Chi2'):
    """Train ``fermi.model`` with cross-entropy plus the FERMI regularizer.

    The first ``initial_epochs`` epochs use cross-entropy only. The next
    ``epochs`` epochs add ``fermi.fairness_regularizer``, which is concave in
    ``fermi.W`` and must be maximized over it. The model parameters are
    therefore updated by gradient descent (Adam, ``initial_learning_rate``),
    while ``W`` gets a separate Adam optimizer stepped on the negated
    gradient, i.e. gradient ascent (``learning_rate_max``).

    Args:
        fermi: FERMI instance providing ``model``, ``W`` and the numpy
            training arrays ``X_train``, ``Y_train``, ``S_train``.
        batch_size: mini-batch size. The final partial batch is also used.
        epochs: number of epochs with the fairness regularizer enabled.
        device: torch device the training tensors are moved to.
        initial_epochs: number of warm-up epochs without the regularizer.
        initial_learning_rate: Adam learning rate for the model parameters.
        lam: unused (the regularizer weight is ``fermi.lam``); kept for
            backward compatibility.
        learning_rate_min: unused; kept for backward compatibility.
        learning_rate_max: Adam learning rate for the ascent step on ``W``.
        f_divergence: divergence name forwarded to the regularizer.

    Returns:
        ``(model state_dict, W detached from the graph)``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    criterion = torch.nn.CrossEntropyLoss()

    # Convert the training arrays once instead of on every iteration.
    X_all = torch.from_numpy(fermi.X_train).float().to(device)
    Y_all = torch.from_numpy(fermi.Y_train).long().to(device)
    S_all = torch.from_numpy(fermi.S_train).float().to(device)
    n_samples = X_all.shape[0]

    model_optimizer = torch.optim.Adam(fermi.model.parameters(), lr=initial_learning_rate)
    w_optimizer = torch.optim.Adam([fermi.W], lr=learning_rate_max)

    for ep in range(initial_epochs + epochs):
        use_fairness = ep >= initial_epochs

        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            XTorch = X_all[start:end]
            YTorch = Y_all[start:end]
            STorch = S_all[start:end]

            _, logits = fermi.model(XTorch)  # (batch, 2)
            loss = criterion(logits, YTorch)
            if use_fairness:
                loss = loss + fermi.fairness_regularizer(XTorch, STorch, f_divergence)

            model_optimizer.zero_grad()
            w_optimizer.zero_grad()
            loss.backward()
            model_optimizer.step()

            # W maximizes the variational bound: flip its gradient so Adam's
            # descent step becomes an ascent step.
            if use_fairness and fermi.W.grad is not None:
                fermi.W.grad.neg_()
                w_optimizer.step()

    return fermi.model.state_dict(), fermi.W.detach()


# Per-dataset metadata used by load_data:
#   (sensitive column index counted INCLUDING the label in column 0,
#    categorical indices among the attribute columns,
#    continuous indices among the attribute columns).
_DATASET_META = {
    'credit_approval': (1,  # Gender
                        [0, 3, 4, 5, 6, 8, 9, 11, 12],
                        [1, 2, 7, 10, 13, 14]),
    'adult': (8,  # sex
              [1, 2, 3, 4, 5, 6, 7, 9],
              [0, 8]),
    'credit_default': (2,  # SEX
                       [1, 2, 3, 5, 6, 7, 8, 9, 10],
                       [0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]),
    'law_admission': (6,  # gender
                      [4, 5, 6],
                      [0, 1, 2, 3]),
    'german': (1,  # Gender
               [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],
               [3, 4, 5]),
    'por': (2,  # sex
            [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
            [2, 28, 29, 30]),
}


def _print_dataset_stats(name, data, label_col, sensi_col, pos_label=1):
    """Print label / sensitive-attribute proportions and the label DP gap
    between the first two sensitive groups (in order of appearance)."""
    print(f"\nDataset: {name}")

    # Sensitive-attribute value proportions.
    sen_col_ratio = sensi_col.value_counts(normalize=True)
    print(f"\n✅ 敏感属性列的值比例：")
    print(sen_col_ratio)

    # Label value proportions.
    col0_ratio = label_col.value_counts(normalize=True)
    print("\n✅ 标签的值比例：")
    print(col0_ratio)

    groups = sensi_col.unique()
    if len(groups) < 2:
        return

    # Demographic-parity gap of the ground-truth labels.
    g1, g2 = groups[0], groups[1]
    g1_data = data[sensi_col == g1]
    g2_data = data[sensi_col == g2]
    g1_total = len(g1_data)
    g2_total = len(g2_data)
    g1_pos = (g1_data.iloc[:, 0] == pos_label).sum()
    g2_pos = (g2_data.iloc[:, 0] == pos_label).sum()
    dp = abs(g1_pos / g1_total - g2_pos / g2_total)
    print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

    # More detailed sample counts.
    pos_total = (label_col == pos_label).sum()
    neg_total = (label_col != pos_label).sum()
    print(f"\n🔍 样本详细统计：")
    print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
    print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
    print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
    print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
    print(f"  - 正类样本总数：{pos_total}")
    print(f"  - 负类样本总数：{neg_total}")


def load_data(name):
    """Load a processed dataset CSV and print label / sensitive statistics.

    Args:
        name: dataset name without extension; must be a key of
            ``_DATASET_META``.

    Returns:
        ``(data, sen_index, cat_indices, cont_indices)``. ``data`` is the
        DataFrame with the label in column 0. ``sen_index`` is the sensitive
        attribute's index among the attribute columns, i.e. excluding the
        label. If the CSV file does not exist, returns
        ``(None, None, None, None)``.

    Raises:
        ValueError: if ``name`` is not a known dataset.
    """
    if name not in _DATASET_META:
        raise ValueError(
            f'Unknown dataset {name!r}; expected one of {sorted(_DATASET_META)}')
    sen_index, cat_indices, cont_indices = _DATASET_META[name]

    src_path = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(src_path, '..', 'Datasets/processed dataset/')
    name = name + '.csv'
    print(dataset_path)

    data_path = os.path.join(dataset_path, name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None
    data = pd.read_csv(data_path)

    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]
    _print_dataset_stats(name, data, label_col, sensi_col)

    # Return fresh lists so callers cannot mutate the shared metadata.
    return data, sen_index - 1, list(cat_indices), list(cont_indices)

def _deduplicate_rows(data):
    """Drop rows whose attribute part repeats an earlier row.

    Column 0 (the label) is ignored in the comparison, so of two rows with
    identical attributes only the first is kept. A set of tuples keeps the
    check linear in the number of rows.
    """
    seen_features = set()
    kept_rows = []
    for row in data.tolist():
        features = tuple(row[1:])
        if features not in seen_features:
            seen_features.add(features)
            kept_rows.append(row)
    return np.array(kept_rows)


def _mean_std_report(metric_name, dataset_name, values):
    """Format the mean and (population) standard deviation of ``values``."""
    return (f'Average {metric_name} of {dataset_name}: {np.mean(values):.6f}\n'
            f'Std {metric_name} of {dataset_name}: {np.std(values):.6f}\n')


def main():
    """Run stratified k-fold FERMI experiments over every dataset and lambda.

    For each dataset the mean/std of accuracy, F1, recall, DP, EO and GE are
    appended to ``<cwd>/FairGBFC/comparative_result//Fair/FERM/<name>.txt``.
    """
    warnings.filterwarnings("ignore")  # ignore warning
    filenames = ['credit_approval', 'adult', 'credit_default', 'law_admission', 'german', 'por']
    n_splits = 5
    random_seed = 42
    np.random.seed(random_seed)
    # Seed torch as well so MLP initialisation is reproducible across runs.
    torch.manual_seed(random_seed)
    # NOTE: '__file__' is a string literal here, so this resolves to the
    # current working directory; results are written relative to the CWD.
    src_path = os.path.dirname(os.path.realpath('__file__'))

    lamdas = [0.1, 0.3, 0.5, 0.8, 1, 1.2, 1.5, 1.8, 2]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result_path = os.path.join(src_path, 'FairGBFC/comparative_result/')

    for dataset_name in filenames:
        file_path = result_path + '/Fair/FERM/' + str(dataset_name) + '.txt'
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, mode='a') as result_file:
            result_file.write(dataset_name + '\n\n')
            data_frame, sen_index, cat_indices, cont_indices = load_data(dataset_name)
            if data_frame is None:
                print(f'Skipping {dataset_name}: dataset could not be loaded.')
                continue
            args.sen_index = sen_index
            print(data_frame)
            print(args.sen_index)

            data = _deduplicate_rows(data_frame.values)
            numberSample = data.shape[0]

            # Scale the attributes to [0, 1]; the label (column 0) is kept as-is.
            minMax = MinMaxScaler()
            data = np.hstack((data[:, 0].reshape(numberSample, 1),
                              minMax.fit_transform(data[:, 1:])))
            train_data = data[:, 1:]
            train_target = data[:, 0]

            skf = StratifiedKFold(n_splits, shuffle=True, random_state=random_seed)
            for lanbda in lamdas:
                args.lam1 = lanbda

                acc_list, f1_list, recall_list = [], [], []
                dp_list, ge_list, eo_list = [], [], []

                for train_index, test_index in skf.split(train_data, train_target):
                    train, test = data[train_index], data[test_index]
                    X_test = test[:, 1:]
                    Y_test = test[:, 0]
                    S_test = X_test[:, args.sen_index]
                    X_train = train[:, 1:]
                    s_train = X_train[:, args.sen_index]
                    y_train = train[:, 0]

                    mlp = MLP(arch=[train_data.shape[1], 128, 64, 32, 16, 2]).to(args.device)
                    S_train_sub = one_hot_encode_sensitive(s_train)

                    # FERMI allocates W on `device` itself.
                    fermi = FERMI(model=mlp,
                                  device=device,
                                  X_train=X_train,
                                  Y_train=y_train,
                                  S_train=S_train_sub,
                                  batch_size=min(256, X_train.shape[0]),
                                  epochs=300,
                                  lam=args.lam1)
                    fermi.model.to(device)

                    # Full-batch training: 200 warm-up epochs, then 300 fair epochs.
                    fair_training(
                        fermi,
                        batch_size=X_train.shape[0],
                        epochs=300,
                        device=device,
                        initial_epochs=200,
                        initial_learning_rate=args.learning_rate_init,
                        lam=args.lam2,
                        f_divergence='Chi2'
                    )
                    fermi.model.eval()

                    with torch.no_grad():
                        X_test_tensor = torch.from_numpy(X_test).float().to(args.device)
                        predict_label = fermi.model.predict(X_test_tensor).detach().cpu().numpy()

                    acc_list.append(accuracy_score(Y_test, predict_label))
                    f1_list.append(f1_score(Y_test, predict_label))
                    recall_list.append(recall_score(Y_test, predict_label))
                    dp_list.append(demographic_parity_difference(
                        Y_test, predict_label, sensitive_features=S_test))
                    eo_list.append(equalized_odds_difference(
                        Y_test, predict_label, sensitive_features=S_test))
                    ge_list.append(generalized_entropy_error(Y_test, predict_label))

                # Write all metrics (same format as the Reweight baseline).
                result_file.write(('lambda is ' + str(args.lam1) + '\n'))
                result_file.write(_mean_std_report('accuracy', dataset_name, acc_list))
                result_file.write(_mean_std_report('f1', dataset_name, f1_list))
                result_file.write(_mean_std_report('recall', dataset_name, recall_list))
                result_file.write(_mean_std_report('dp', dataset_name, dp_list))
                result_file.write(_mean_std_report('eo', dataset_name, eo_list))
                result_file.write(_mean_std_report('ge', dataset_name, ge_list))
                result_file.write('\n\n')
                print(dataset_name, 'done!')

            result_file.write('all done!!!!!\n')


if __name__ == "__main__":
    main()