import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.preprocessing import LabelEncoder
import os
import argparse
from fairlearn.datasets import fetch_acs_income

from utils import *

def _load_credit(test_data):
    """Load the credit-default dataset and, optionally, a synthetic test CSV.

    The label column "default payment next month" is renamed to "target"
    (1 = default, 0 = no default) in both frames.

    Args:
        test_data: "real" to skip loading synthetic data, otherwise the path
            of a synthetic CSV with the same column names.

    Returns:
        (df, df_syn) where df_syn is None when test_data == "real".
    """
    df = pd.read_csv("datasets/default_credit.csv", header=0)
    df = df.rename(columns={"default payment next month": "target"})
    # Coerce every column to numeric and drop rows that failed to parse.
    df = df.apply(pd.to_numeric, errors="coerce").dropna()
    if test_data == "real":
        return df, None
    df_syn = pd.read_csv(test_data, header=0)
    df_syn = df_syn.rename(columns={"default payment next month": "target"})
    return df, df_syn


_ACS_TARGET_COL = "PINCP"  # income column that is binarised into "target"
_ACS_THRESHOLD = 50000  # income above this value -> target 1


def _class_filename(test_data):
    """Return '<test_data without extension>_class.csv' (where labelled synthetic data is saved)."""
    return f"{os.path.splitext(test_data)[0]}_class.csv"


def _load_acs(test_data):
    """Load the ACS income data and, optionally, a synthetic test CSV.

    Rows with missing values are dropped and a binary "target" column
    (income > _ACS_THRESHOLD) is added to both frames. The labelled synthetic
    frame is also written to _class_filename(test_data).

    Returns:
        (df, df_syn) where df_syn is None when test_data == "real".
    """
    df = fetch_acs_income(as_frame=True).frame
    print(f"Number of rows in df: {len(df)}")
    df = df.dropna()
    df["target"] = (df[_ACS_TARGET_COL] > _ACS_THRESHOLD).astype(int)
    if test_data == "real":
        return df, None
    df_syn = pd.read_csv(test_data, header=0).dropna()
    df_syn["target"] = (df_syn[_ACS_TARGET_COL] > _ACS_THRESHOLD).astype(int)
    df_syn.to_csv(_class_filename(test_data), index=False)
    return df, df_syn


def _preprocess_adult(frame, known_classes=None):
    """Encode an Adult-format frame as integers.

    Renames "class" to "target", maps "<=50K"/">50K" to 0/1 and
    "Male"/"Female" to 1/2, then integer-encodes every remaining
    object/category column.

    Args:
        frame: raw Adult-format DataFrame (not modified).
        known_classes: optional {column: array of values}; listed columns are
            encoded as the value's position in that array (unseen values
            become -1) so codes match a previously encoded frame. Other
            columns get a freshly fitted LabelEncoder.

    Returns:
        (encoded_frame, classes) where classes maps every encoded column to
        its value array, suitable for passing back as known_classes.
    """
    frame = frame.rename(columns={"class": "target"})
    # infer_objects() keeps the numeric dtype that older pandas produced
    # implicitly when replace() swapped every string for an int.
    frame["target"] = frame["target"].replace({"<=50K": 0, ">50K": 1}).infer_objects()
    if "sex" in frame.columns:
        frame["sex"] = frame["sex"].replace({"Male": 1, "Female": 2}).infer_objects()
    classes = dict(known_classes or {})
    for col in frame.select_dtypes(include=["object", "category"]).columns:
        if col in classes:
            codes = {value: code for code, value in enumerate(classes[col])}
            frame[col] = frame[col].astype(object).map(codes).fillna(-1).astype(int)
        else:
            le = LabelEncoder()
            frame[col] = le.fit_transform(frame[col])
            classes[col] = le.classes_
    return frame, classes


def _load_adult(test_data):
    """Load and encode the Adult dataset and, optionally, a synthetic test CSV.

    The synthetic frame is encoded with the real data's category tables so
    both use identical integer codes, then written to
    _class_filename(test_data).

    Returns:
        (df, df_syn) where df_syn is None when test_data == "real".
    """
    df, classes = _preprocess_adult(pd.read_csv("datasets/adult.csv", header=0))
    if test_data == "real":
        return df, None
    df_syn, _ = _preprocess_adult(pd.read_csv(test_data, header=0), known_classes=classes)
    df_syn.to_csv(_class_filename(test_data), index=False)
    return df, df_syn


# data_name -> (loader, columns excluded from the features, sensitive column)
_DATASETS = {
    "credit": (_load_credit, ["target"], "SEX"),
    "acs": (_load_acs, [_ACS_TARGET_COL, "target"], "SEX"),
    "adult": (_load_adult, ["target"], "sex"),
}


def _split_xya(frame, drop_cols, sensitive_col):
    """Return (features without drop_cols, frame["target"], frame[sensitive_col])."""
    return frame.drop(columns=drop_cols), frame["target"], frame[sensitive_col]


def _fit_tree(X, y, tree_depth, feature_names, out_path):
    """Fit a depth-limited decision tree, write its leaf-label dump to out_path and return it."""
    clf = DecisionTreeClassifier(max_depth=tree_depth, random_state=42)
    clf.fit(X, y)
    tree_text = export_leaf_labels(clf, feature_names=feature_names)
    with open(out_path, "w") as f:
        f.write(tree_text)
    return clf


def main(data_name, test_data, size_train, size_test, tree_depth):
    """Train a deliberately biased and a "fair" decision tree and print fairness metrics.

    Args:
        data_name: "credit", "acs" or "adult".
        test_data: "real" to evaluate on the first size_test real rows,
            otherwise the path of a synthetic CSV with the same columns.
        size_train: number of real rows (taken after the first size_test
            rows) used for training.
        size_test: number of test rows.
        tree_depth: max_depth of both trees.

    Raises:
        ValueError: if data_name is not a known dataset.

    Side effects:
        Writes Model/{data_name}_dt_unfair.txt and Model/{data_name}_dt_fair.txt;
        for acs/adult with synthetic test data also writes the labelled
        synthetic CSV next to test_data.
    """
    if data_name not in _DATASETS:
        raise ValueError(f"Unknown dataset: {data_name}")
    loader, drop_cols, sensitive_col = _DATASETS[data_name]

    # === Load and preprocess data ===
    df, df_syn = loader(test_data)
    X, y, A = _split_xya(df, drop_cols, sensitive_col)
    if df_syn is None:
        X_test, y_test, A_test = X, y, A
    else:
        X_test, y_test, A_test = _split_xya(df_syn, drop_cols, sensitive_col)
    # Use the first size_test rows as the test data.
    X_test = X_test.iloc[:size_test]
    y_test = y_test.iloc[:size_test]
    A_test = A_test.iloc[:size_test]

    # Train on the size_train real rows that follow the first size_test rows
    # (they are skipped even when the test data is synthetic).
    train_rows = slice(size_test, size_test + size_train)
    X_train = X.iloc[train_rows]
    y_train = y.iloc[train_rows]
    A_train = A.iloc[train_rows]

    # Give the test features exactly the training columns in the same order,
    # so predict() sees the features the trees were fitted on.
    missing = X_train.columns.difference(X_test.columns)
    if len(missing):
        print(f"Warning: test data lacks columns {list(missing)}; filling them with 0")
    X_train, X_test = X_train.align(X_test, join="left", axis=1, fill_value=0)

    feature_names_all = [str(i) for i in range(X_train.shape[1])]
    os.makedirs("Model", exist_ok=True)

    # === Unfair model (with the sensitive attribute) ===
    # Force the positive label for every training row whose sensitive value
    # is 1, so the tree learns to favour that group.
    y_train_biased = y_train.copy()
    y_train_biased.loc[A_train == 1] = 1
    clf_biased = _fit_tree(X_train, y_train_biased, tree_depth, feature_names_all,
                           f"Model/{data_name}_dt_unfair.txt")
    y_pred_biased = clf_biased.predict(X_test)

    dp_diff_b, rate_m_b, rate_f_b = compute_dp(y_pred_biased, A_test)
    eqod_diff_b, tpr_b, fpr_b = compute_eqod(y_test, y_pred_biased, A_test)
    # MRD on hard labels; pass clf.predict_proba(X_test)[:, 1] instead to use scores.
    mrd_diff_b, res_m_b, res_f_b = compute_mrd(y_test, y_pred_biased, A_test)

    # === Fair model (sensitive attribute blanked out) ===
    # A constant column gives the tree nothing to split on, so the fair tree
    # cannot use the sensitive attribute; test rows keep their real values.
    X_train_fair = X_train.copy()
    X_train_fair[sensitive_col] = 0
    clf_fair = _fit_tree(X_train_fair, y_train, tree_depth, feature_names_all,
                         f"Model/{data_name}_dt_fair.txt")
    y_pred_fair = clf_fair.predict(X_test)

    dp_diff_f, rate_m_f, rate_f_f = compute_dp(y_pred_fair, A_test)
    eqod_diff_f, tpr_f, fpr_f = compute_eqod(y_test, y_pred_fair, A_test)
    mrd_diff_f, res_m_f, res_f_f = compute_mrd(y_test, y_pred_fair, A_test)

    # === Print comparison ===
    print("=== Biased Tree ===")
    print(f"Demographic Parity Difference:    {dp_diff_b:.4f} (M: {rate_m_b:.4f}, F: {rate_f_b:.4f})")
    print(f"Equal Odds Difference (TPR DIFF): {eqod_diff_b[0]:.4f}, (M: {tpr_b[0]:.4f}, F: {tpr_b[1]:.4f})")
    print(f"Equal Odds Difference (FPR DIFF): {eqod_diff_b[1]:.4f}, (M: {fpr_b[0]:.4f}, F: {fpr_b[1]:.4f})")
    print(f"Mean Residual Difference:       {mrd_diff_b:.4f} (M: {res_m_b:.4f}, F: {res_f_b:.4f})\n")

    print("=== Fair Tree (trained without sensitive attribute) ===")
    print(f"Demographic Parity Difference:  {dp_diff_f:.4f} (M: {rate_m_f:.4f}, F: {rate_f_f:.4f})")
    print(f"Equal Odds Difference (TPR DIFF): {eqod_diff_f[0]:.4f}, (M: {tpr_f[0]:.4f}, F: {tpr_f[1]:.4f})")
    print(f"Equal Odds Difference (FPR DIFF): {eqod_diff_f[1]:.4f}, (M: {fpr_f[0]:.4f}, F: {fpr_f[1]:.4f})")
    print(f"Mean Residual Difference:       {mrd_diff_f:.4f} (M: {res_m_f:.4f}, F: {res_f_f:.4f})")

    print("\n\n")

def parse_args(argv=None):
    """Parse the command-line options for the fairness experiment.

    Args:
        argv: argument list to parse; None means sys.argv[1:].

    Returns:
        argparse.Namespace with data, testdata, trainsize, testsize, tree_depth.
    """
    parser = argparse.ArgumentParser(
        description="Train biased and fair decision trees and compare their fairness metrics"
    )
    parser.add_argument("--data", type=str, default="credit", choices=["credit", "acs", "adult"],
                        help="Name of the dataset")
    parser.add_argument("--testdata", type=str, default="real",
                        help='"real" to test on held-out real rows, otherwise path to a synthetic test CSV')
    parser.add_argument("--trainsize", type=int, default=1000, help="Number of training samples for decision tree")
    parser.add_argument("--testsize", type=int, default=1000, help="Number of test samples for decision tree")
    parser.add_argument("--tree_depth", type=int, default=10, help="Depth of the decision tree")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    main(args.data, args.testdata, args.trainsize, args.testsize, args.tree_depth)
