import pandas as pd
from sklearn.model_selection import train_test_split


def get_compas(
    test_size=0.2,
    random_state=0,
    sensitive_attr="race",
):
    """Load the cleaned COMPAS data and return a reproducible train/test split.

    Args:
        test_size: Fraction (or count) of samples placed in the test split;
            forwarded to ``sklearn.model_selection.train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.
        sensitive_attr: Which cleaned column is the protected attribute;
            must be ``"race"`` or ``"gender"``.

    Returns:
        ``(train_set, test_set)``, each a ``(features, sensitive, labels)``
        tuple. ``features`` is every cleaned column except ``target``, so it
        still contains both ``race`` and ``gender``. ``sensitive`` is a
        boolean array and ``labels`` holds the ``target`` values.

    Raises:
        ValueError: If ``sensitive_attr`` is neither ``"race"`` nor ``"gender"``.
    """
    valid_attrs = ("race", "gender")
    if sensitive_attr not in valid_attrs:
        raise ValueError("sensitive_attr must be either race or gender")

    cleaned = clean_compas(load_compas_raw())
    protected = extract_sensitive(cleaned, sensitive_attr)
    x, y = preprocess_compas(cleaned)

    # train_test_split splits all three inputs consistently and returns them
    # as consecutive (train, test) pairs in argument order.
    x_tr, x_te, s_tr, s_te, y_tr, y_te = train_test_split(
        x, protected, y, test_size=test_size, random_state=random_state
    )
    return (x_tr, s_tr, y_tr), (x_te, s_te, y_te)


def extract_sensitive(df, sensitive_attr):
    """Return column ``sensitive_attr`` of ``df`` as a NumPy boolean array.

    Non-zero values become ``True`` and zeros become ``False``. When ``df``
    comes from ``clean_compas``, ``True`` therefore means Caucasian (for
    ``"race"``) or Male (for ``"gender"``).
    """
    column = df[sensitive_attr]
    as_array = column.values
    return as_array.astype(bool)


def load_compas_raw(path="datasets/compas-scores-two-years.csv"):
    """Read the raw COMPAS two-year recidivism CSV.

    Args:
        path: Location of the CSV file. The default is a relative path, so it
            is resolved against the current working directory. Pass an
            explicit path to load the data from anywhere else.

    Returns:
        The CSV contents as a DataFrame, using the file's first column as the
        index.
    """
    return pd.read_csv(path, index_col=0)


def preprocess_compas(df):
    """Split a cleaned COMPAS frame into features and labels.

    Returns:
        ``(features, labels)``: ``features`` is a new DataFrame holding every
        column except ``target``, and ``labels`` is the NumPy array of
        ``target`` values. ``df`` itself is left untouched.
    """
    target_column = "target"
    feature_frame = df.drop(columns=target_column)
    label_array = df[target_column].values
    return feature_frame, label_array


def clean_compas(df):
    """Select, filter and numerically encode the raw COMPAS table.

    Same preprocessing as
    https://github.com/HsiangHsu/Fair-Projection/blob/main/baseline-methods/DataLoader.py,
    which follows ProPublica's row filtering.

    Args:
        df: Raw COMPAS DataFrame containing at least the columns ``age``,
            ``c_charge_degree``, ``race``, ``sex``, ``priors_count``,
            ``days_b_screening_arrest``, ``is_recid``, ``c_jail_in`` and
            ``c_jail_out``. It is not modified.

    Returns:
        A new DataFrame with a fresh 0..n-1 index and, in this order, the columns
        ``age``, ``c_charge_degree`` (felony "F" -> 1, otherwise -1),
        ``race`` (Caucasian -> 1, African-American -> 0), ``gender``
        (Male -> 1, otherwise 0), ``priors_count``, ``target`` (the original
        ``is_recid``) and ``length_of_stay`` (whole days between jail-in and
        jail-out).
    """
    # Select the features used for analysis. ``.copy()`` gives an independent
    # frame, so the column assignments below never write through a view of
    # the caller's data (no SettingWithCopyWarning / chained assignment).
    df = df[
        [
            "age",
            "c_charge_degree",
            "race",
            "sex",
            "priors_count",
            "days_b_screening_arrest",
            "is_recid",
            "c_jail_in",
            "c_jail_out",
        ]
    ].copy()

    # Drop missing/bad rows following ProPublica's analysis; ``keep`` marks
    # the rows to retain.
    # Remove entries with inconsistent arrest information (NaN compares
    # False, so rows with a missing value are dropped as well).
    keep = (df["days_b_screening_arrest"] <= 30) & (
        df["days_b_screening_arrest"] >= -30
    )
    # Remove entries where the COMPAS case could not be found.
    keep &= df["is_recid"] != -1
    # Remove traffic offenses.
    keep &= df["c_charge_degree"] != "O"
    # Keep only African-American and Caucasian defendants.
    keep &= df["race"].isin(["African-American", "Caucasian"])
    df = df.loc[keep].copy()

    # New attribute "length of stay": total jail time in whole days.
    jail_in = pd.to_datetime(df["c_jail_in"])
    jail_out = pd.to_datetime(df["c_jail_out"])
    df["length_of_stay"] = (jail_out - jail_in).dt.days

    # Drop columns that won't be used.
    df = df.drop(columns=["c_jail_in", "c_jail_out", "days_b_screening_arrest"])

    # Binarize race: Caucasian -> 1, African-American -> 0.
    df["race"] = (df["race"] == "Caucasian").astype(int)

    # Binarize gender: Male -> 1, anything else (Female) -> 0.
    df["sex"] = (df["sex"] == "Male").astype(int)

    # Binarize degree charged: Felony ("F") -> 1, Misdemeanor -> -1.
    df["c_charge_degree"] = (
        df["c_charge_degree"].apply(lambda x: 1 if x == "F" else -1).astype(int)
    )

    df = df.rename(columns={"sex": "gender", "is_recid": "target"})
    return df.reset_index(drop=True)
