import os

import numpy as np
import pandas as pd
from aif360.datasets import StandardDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

"""
Code from 
"Beyond Adult and COMPAS: Fairness in Multi-Class Prediction via Information Projection"
https://github.com/HsiangHsu/Fair-Projection
"""


def get_idx_w_protected(feature_names):
    """Return the positional index of every feature, protected ones included.

    Args:
        feature_names: Sized collection of feature names; only its length
            is used.

    Returns:
        ``[0, 1, ..., len(feature_names) - 1]`` as a list of ints.
    """
    # Build the list straight from the range: routing it through a set made
    # the ascending order depend on set iteration order for no benefit.
    return list(range(len(feature_names)))


def get_idx_protected(feature_names, protected_attrs):
    """Look up the column positions of the protected attributes.

    Args:
        feature_names: List of feature names (must support ``.index``).
        protected_attrs: Iterable of protected attribute names.

    Returns:
        List of the distinct positions of ``protected_attrs`` within
        ``feature_names``; duplicates are collapsed.

    Raises:
        ValueError: If a protected attribute is not in ``feature_names``.
    """
    unique_positions = {feature_names.index(attr) for attr in protected_attrs}
    return list(unique_positions)


def load_any(
    ds_name,
    test_size=0.3,
    seed=0,
):
    """Load a dataset by name and split it into train and test tuples.

    Args:
        ds_name: One of ``"adult"``, ``"compas"`` or ``"hsls"``.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        seed: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        ``(train_set, test_set)``. Each is a tuple
        ``(features, sensitive, labels)`` where ``features`` is a DataFrame
        (default integer column labels) holding every feature column,
        the protected one included, ``sensitive`` is the flattened protected
        column and ``labels`` is the flattened label array. The train
        ``sensitive`` array is converted to boolean (``== 1``); the test one
        is returned as extracted.

    Raises:
        ValueError: If ``ds_name`` is not a supported dataset.
    """
    if ds_name == "adult":
        protected_attrs = ["race"]
        label_name = "income"
        df = load_data("adult")
    elif ds_name == "compas":
        protected_attrs = ["race"]
        label_name = "is_recid"
        df = load_data("compas")
    elif ds_name == "hsls":
        protected_attrs = ["racebin"]
        label_name = "gradebin"
        df = load_hsls_imputed()
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local further down.
        raise ValueError(
            f"Unknown dataset name: {ds_name!r}; "
            "expected 'adult', 'compas' or 'hsls'"
        )

    def _as_standard_dataset(frame):
        # Both splits are wrapped with exactly the same settings: label value
        # 1 is the favorable class and protected value 1 the privileged one.
        return StandardDataset(
            frame,
            label_name=label_name,
            favorable_classes=[1],
            protected_attribute_names=protected_attrs,
            privileged_classes=[[1]],
        )

    frame_train, frame_test = train_test_split(
        df, test_size=test_size, random_state=seed
    )
    dataset_orig_train = _as_standard_dataset(frame_train)
    dataset_orig_test = _as_standard_dataset(frame_test)

    # Column positions are derived from the train split and reused for the
    # test split.
    feature_names = dataset_orig_train.feature_names
    idx_features = get_idx_w_protected(feature_names)
    idx_protected = get_idx_protected(feature_names, protected_attrs)

    features_train = pd.DataFrame(data=dataset_orig_train.features[:, idx_features])
    features_test = pd.DataFrame(data=dataset_orig_test.features[:, idx_features])
    labels_train = dataset_orig_train.labels.ravel()
    labels_test = dataset_orig_test.labels.ravel()

    sensitive_train = dataset_orig_train.features[:, idx_protected].ravel() == 1
    # NOTE(review): unlike the train array, the test array is NOT converted to
    # boolean. Kept as-is to preserve the returned dtype; confirm with callers
    # whether this asymmetry is intended.
    sensitive_test = dataset_orig_test.features[:, idx_protected].ravel()

    train_set = (features_train, sensitive_train, labels_train)
    test_set = (features_test, sensitive_test, labels_test)

    return train_set, test_set


#  method for loading different datasets
def load_data(name="adult"):
    """Load and preprocess one of the CSV-based datasets.

    Args:
        name: ``"adult"`` or ``"compas"``.

    Returns:
        The preprocessed DataFrame with a fresh ``0..n-1`` index.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    if name == "adult":
        return _load_adult()
    if name == "compas":
        return _load_compas()

    # TODO: add other datasets here

    raise ValueError(
        f"Unknown dataset name: {name!r}; expected 'adult' or 'compas'"
    )


def _load_adult():
    """Read the UCI Adult train and test files and return one preprocessed frame."""
    file = os.path.join("exisiting_ds", "adult.data")
    fileTest = os.path.join("exisiting_ds", "adult.test")

    df = pd.read_csv(file, header=None, sep=",\\s+", engine="python")
    # The first line of the test file is skipped.
    dfTest = pd.read_csv(
        fileTest, header=None, skiprows=1, sep=",\\s+", engine="python"
    )

    columnNames = [
        "age",
        "workclass",
        "fnlwgt",
        "education",
        "education-num",
        "marital-status",
        "occupation",
        "relationship",
        "race",
        "gender",
        "capital-gain",
        "capital-loss",
        "hours-per-week",
        "native-country",
        "income",
    ]

    df.columns = columnNames
    dfTest.columns = columnNames

    # Both files are pooled; the train/test split is done by the caller.
    df = pd.concat([df, dfTest], ignore_index=True)

    # drop columns that won't be used
    dropCol = ["fnlwgt", "workclass", "occupation"]
    df.drop(dropCol, inplace=True, axis=1)

    # keep only entries marked as ``White'' or ``Black''
    # (copy so the column assignments below act on an independent frame)
    ix = df["race"].isin(["White", "Black"])
    df = df.loc[ix, :].copy()

    # binarize race
    # White = 0; Black = 1
    df.loc[:, "race"] = df["race"].apply(lambda x: 0 if x == "White" else 1)

    # binarize gender
    # Male = 1; anything else = 0
    df.loc[:, "gender"] = df["gender"].apply(lambda x: 1 if x == "Male" else 0)

    # binarize income
    # labels starting with '>' = 1; everything else = 0
    df.loc[:, "income"] = df["income"].apply(lambda x: 1 if x[0] == ">" else 0)

    # drop "education" and native-country (education already encoded in education-num)
    features_to_drop = ["education", "native-country"]
    df.drop(features_to_drop, inplace=True, axis=1)

    # "income" and "race" are forced into the numeric set so they are never
    # one-hot encoded, whatever dtype the binarized columns ended up with.
    # NOTE(review): "gender" is not forced, so whether it is one-hot encoded
    # depends on the dtype it has at this point.
    numerical_features = set(df._get_numeric_data().columns)
    numerical_features.add("income")
    numerical_features.add("race")

    # create one-hot encoding
    # Follow the frame's column order so the dummy columns come out in the
    # same order on every run (a set difference would not guarantee that).
    categorical_features = [
        col for col in df.columns if col not in numerical_features
    ]
    df = pd.concat(
        [df, pd.get_dummies(df[categorical_features])], axis=1, sort=False
    )
    df.drop(categorical_features, inplace=True, axis=1)

    # reset index
    df.reset_index(inplace=True, drop=True)

    return df


def _load_compas():
    """Read the COMPAS two-year CSV and return the filtered, binarized frame."""
    file = os.path.join("exisiting_ds", "compas-scores-two-years.csv")
    df = pd.read_csv(file, index_col=0)

    # select features for analysis
    df = df[
        [
            "age",
            "c_charge_degree",
            "race",
            "sex",
            "priors_count",
            "days_b_screening_arrest",
            "is_recid",
            "c_jail_in",
            "c_jail_out",
        ]
    ]

    # drop missing/bad features (following ProPublica's analysis)
    # ix is a boolean mask of the rows we want to keep.

    # Remove entries with inconsistent arrest information.
    ix = df["days_b_screening_arrest"] <= 30
    ix = (df["days_b_screening_arrest"] >= -30) & ix

    # remove entries where compas case could not be found.
    ix = (df["is_recid"] != -1) & ix

    # remove traffic offenses.
    ix = (df["c_charge_degree"] != "O") & ix

    # trim dataset (copy so new columns are added to an independent frame)
    df = df.loc[ix, :].copy()

    # create new attribute "length of stay" with total jail time in whole days.
    df["length_of_stay"] = (
        pd.to_datetime(df["c_jail_out"]) - pd.to_datetime(df["c_jail_in"])
    ).apply(lambda x: x.days)

    # drop 'c_jail_in' and 'c_jail_out'
    # drop columns that won't be used
    dropCol = ["c_jail_in", "c_jail_out", "days_b_screening_arrest"]
    df.drop(dropCol, inplace=True, axis=1)

    # keep only African-American and Caucasian
    df = df.loc[df["race"].isin(["African-American", "Caucasian"]), :].copy()

    # binarize race
    # African-American: 0, Caucasian: 1
    df.loc[:, "race"] = df["race"].apply(lambda x: 1 if x == "Caucasian" else 0)

    # binarize gender
    # Male: 1, anything else: 0
    df.loc[:, "sex"] = df["sex"].apply(lambda x: 1 if x == "Male" else 0)

    # rename column 'sex' to 'gender'
    # (the index is not touched here because it is reset just below)
    df.rename(columns={"sex": "gender"}, inplace=True)

    # binarize degree charged
    # Felony ("F") = 1, anything else = -1
    df.loc[:, "c_charge_degree"] = df["c_charge_degree"].apply(
        lambda x: 1 if x == "F" else -1
    )

    # reset index
    df.reset_index(inplace=True, drop=True)

    return df


def load_hsls_imputed(vars=None):
    """Load the KNN-imputed HSLS pickle and derive the binary columns.

    Args:
        vars: Optional sequence of column names to keep before any other
            processing. ``None`` or an empty sequence keeps every column.
            When given, it must still include the columns read below
            (``grade9thbin``, ``grade12thbin``, ``studentrace``,
            ``studentgender``).

    Returns:
        DataFrame with ``gradebin``, ``racebin`` and ``sexbin`` added, the
        four source columns above removed, and every column min-max scaled.
    """
    ## load the pickled DataFrame
    file = os.path.join("exisiting_ds", "hsls_df_knn_impute_past_v2.pkl")
    # NOTE(security): unpickling can execute arbitrary code; this must only
    # ever point at a trusted, locally produced file.
    df = pd.read_pickle(file)

    ## if no variables specified, include all variables
    if vars is not None and len(vars) > 0:
        df = df[list(vars)]

    ## Setting out-of-range entries to NaN
    ## entries with values of -7 or below are treated as missing
    df = df.mask(df <= -7)

    ## Dropping every row that has a missing value
    ## this step significantly reduces the number of samples
    df = df.dropna()

    ## Creating racebin & gradebin & sexbin variables
    ## gradebin: copy of grade9thbin
    ## sexbin: studentgender cast to int
    ##   (presumably already 0/1-coded upstream -- TODO confirm)
    ## racebin: 0 when studentrace * 7 truncates to 7 or 1, otherwise 1
    ##   NOTE(review): an earlier comment here read "X1RACE: 0 -- BHN, 1 -- WA";
    ##   confirm which race codes 7 and 1 stand for.
    df["gradebin"] = df["grade9thbin"]
    race_code = (df["studentrace"] * 7).astype(int).to_numpy()
    df["racebin"] = np.logical_not(np.isin(race_code, [7, 1])).astype(int)
    df["sexbin"] = df["studentgender"].astype(int)

    ## Dropping race and 12th grade data just to focus on the 9th grade prediction ##
    df = df.drop(
        columns=["studentgender", "grade9thbin", "grade12thbin", "studentrace"]
    )

    ## Scaling ##
    scaler = MinMaxScaler()
    df = pd.DataFrame(scaler.fit_transform(df), columns=df.columns, index=df.index)

    ## No group balancing is applied here.
    return df
