import os

import pandas as pd
from sklearn.model_selection import train_test_split


def get_adult(
    test_size=0.2,
    random_state=0,
    include_race=True,
    include_sex=True,
    sensitive_attr="race",
    dummy=True,
):
    """Load the Adult data and return it as train and test splits.

    The raw data is loaded, cleaned, and preprocessed, and a boolean mask
    for the chosen sensitive attribute is extracted alongside it.

    Args:
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.
        include_race: Keep the ``race`` column among the features.
        include_sex: Keep the ``sex`` column among the features.
        sensitive_attr: Either ``"race"`` or ``"sex"``; selects which
            attribute the sensitive mask is built from.
        dummy: One-hot encode the categorical feature columns.

    Returns:
        ``(train_set, test_set)``, each a tuple of
        ``(features, sensitive, labels)``.

    Raises:
        ValueError: If ``sensitive_attr`` is neither ``"race"`` nor ``"sex"``.
    """
    if sensitive_attr not in ("race", "sex"):
        raise ValueError("sensitive_attr must be either race or sex")

    raw_features, raw_labels = load_adult_raw()
    clean_features, clean_labels = clean_adult(raw_features, raw_labels)

    # Build the mask before preprocessing: preprocessing may drop the
    # race/sex column that the mask is derived from.
    group_mask = extract_sensitive(clean_features, sensitive_attr)

    model_features, targets = preprocess_adult(
        clean_features,
        clean_labels,
        include_race=include_race,
        include_sex=include_sex,
        dummy=dummy,
    )

    # train_test_split returns [a_train, a_test, b_train, b_test, ...], so
    # the even positions are the train parts and the odd positions the test
    # parts, in the order features, sensitive, labels.
    pieces = train_test_split(
        model_features,
        group_mask,
        targets,
        test_size=test_size,
        random_state=random_state,
    )
    train_set = tuple(pieces[0::2])
    test_set = tuple(pieces[1::2])

    return train_set, test_set


def extract_sensitive(features, sensitive_attr):
    """Return a boolean mask marking rows that belong to the sensitive group.

    The sensitive group is ``"Black"`` for ``"race"`` and ``"Female"`` for
    ``"sex"``.

    Args:
        features: DataFrame containing a column named ``sensitive_attr``.
        sensitive_attr: ``"race"`` or ``"sex"``.

    Returns:
        A NumPy boolean array with one entry per row of ``features``; it is
        all ``False`` when the sensitive group does not occur in the data.

    Raises:
        KeyError: If ``sensitive_attr`` is not ``"race"`` or ``"sex"``, or
            the column is missing from ``features``.
    """
    value = {"race": "Black", "sex": "Female"}[sensitive_attr]
    # Compare directly rather than one-hot encoding the whole column and
    # picking one dummy: that lookup fails with KeyError when the group is
    # absent from the data, and builds a frame that is thrown away.
    return (features[sensitive_attr] == value).to_numpy(dtype=bool)


def load_adult_raw():
    """Load the raw Adult features and labels, caching them as CSV files.

    The data is read from ``datasets/adult_features_raw.csv`` and
    ``datasets/adult_labels_raw.csv`` (relative to the current working
    directory). If either file is missing, the dataset is fetched with
    ``ucimlrepo`` (repository id 2) and written to those two files so that
    later calls can read it locally.

    Returns:
        ``(features, labels)`` as pandas DataFrames.
    """
    cache_dir = "datasets"
    features_path = os.path.join(cache_dir, "adult_features_raw.csv")
    labels_path = os.path.join(cache_dir, "adult_labels_raw.csv")

    try:
        features = pd.read_csv(features_path)
        labels = pd.read_csv(labels_path)
    except FileNotFoundError:
        # Imported lazily so ucimlrepo is only required on a cache miss.
        from ucimlrepo import fetch_ucirepo

        adult = fetch_ucirepo(id=2)

        # data (as pandas dataframes)
        features = adult.data.features
        labels = adult.data.targets

        # The cache directory may not exist yet (e.g. on a fresh checkout);
        # create it so the write does not fail after a successful download.
        os.makedirs(cache_dir, exist_ok=True)
        features.to_csv(features_path, index=False)
        labels.to_csv(labels_path, index=False)
    return features, labels


def preprocess_adult(
    features, labels, include_race=False, include_sex=False, dummy=True
):
    """Turn cleaned Adult data into model-ready features and binary labels.

    Args:
        features: DataFrame of Adult features; must contain every column
            that is dropped or standardised here.
        labels: Income labels as strings; a label containing ``">"`` maps
            to ``1.0``, anything else to ``0.0``.
        include_race: Keep the ``race`` column when true.
        include_sex: Keep the ``sex`` column when true.
        dummy: One-hot encode the remaining categorical columns when true.

    Returns:
        ``(features, labels)``: the processed feature DataFrame and the
        binary labels taken from ``.values`` of the mapped input.
    """
    numeric_columns = (
        "age",
        "education-num",
        "capital-gain",
        "capital-loss",
        "hours-per-week",
    )
    unused_columns = [
        "fnlwgt",
        "workclass",
        "occupation",
        "education",
        "native-country",
    ]
    # The sensitive columns are dropped too unless explicitly requested.
    unused_columns += [
        name
        for name, keep in (("sex", include_sex), ("race", include_race))
        if not keep
    ]

    # drop() returns a new frame, so the caller's DataFrame is left untouched.
    features = features.drop(columns=unused_columns)

    # Standardise each continuous column to zero mean and unit std.
    for name in numeric_columns:
        column = features[name]
        features[name] = (column - column.mean()) / column.std()

    # Binarise income.
    labels = labels.map(lambda income: 1.0 if ">" in income else 0.0).values

    if not dummy:
        return features, labels

    # One-hot encode the categorical columns.
    return pd.get_dummies(features), labels


def clean_adult(features, labels):
    """Strip strings, drop incomplete rows, and keep only Black/White rows.

    Args:
        features: DataFrame of raw Adult features; must contain ``race``.
        labels: Labels joinable to ``features`` on the index. They are
            expected to end up as the last column of the joined frame.

    Returns:
        ``(features, labels)``: the filtered feature DataFrame, with its
        object columns converted to ``category`` dtype, and the matching
        labels as a Series. The original index is kept (not reset).
    """
    data = features.join(labels)

    # Strip surrounding whitespace from the string columns so values such as
    # " White" and "?" match the comparisons below.
    df_obj = data.select_dtypes(["object"])
    data[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())

    # "?" marks a missing value; drop every row that has one.
    data = data.replace("?", pd.NA).dropna()

    ix = data["race"].isin(["White", "Black"])
    data = data.loc[ix, :]

    labels = data.iloc[:, -1]  # Series holding the last (label) column
    features = data.iloc[:, :-1]

    # Convert with astype (which returns a new frame) instead of assigning
    # into `features`: it is a slice of `data`, and writing into it triggers
    # pandas' chained-assignment (SettingWithCopy) warning.
    object_columns = features.select_dtypes(include="object").columns
    features = features.astype({column: "category" for column in object_columns})
    return features, labels
