import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Preprocessing from https://github.com/HsiangHsu/Fair-Projection/blob/main/baseline-methods/DataLoader.py.


def get_hsls(
    test_size=0.2,
    random_state=0,
    sensitive_attr="racebin",
):
    """Load, clean and split the HSLS dataset into train and test sets.

    Args:
        test_size: Size of the test split; forwarded unchanged to
            ``train_test_split``.
        random_state: Seed forwarded unchanged to ``train_test_split``.
        sensitive_attr: Column used as the sensitive attribute; must be
            ``"racebin"`` or ``"sexbin"``.

    Returns:
        ``(train_set, test_set)``; each is a ``(features, sensitive, labels)``
        tuple.

    Raises:
        ValueError: If ``sensitive_attr`` is not one of the supported names.
    """
    allowed = ("racebin", "sexbin")
    if sensitive_attr not in allowed:
        raise ValueError("sensitive_attr must be either racebin or sexbin")

    cleaned = clean_hsls(load_hsls_raw())
    sensitive = extract_sensitive(cleaned, sensitive_attr)
    features, labels = preprocess_hsls(cleaned)

    # train_test_split yields a (train, test) pair per input array, in order:
    # [features_tr, features_te, sensitive_tr, sensitive_te, labels_tr, labels_te]
    splits = train_test_split(
        features, sensitive, labels, test_size=test_size, random_state=random_state
    )
    train_set = tuple(splits[0::2])
    test_set = tuple(splits[1::2])
    return train_set, test_set


def extract_sensitive(df, sensitive_attr):
    """Return a mask that is True wherever ``df[sensitive_attr]`` equals 0.

    Args:
        df: DataFrame holding the sensitive column.
        sensitive_attr: Name of the column to compare against 0.

    Returns:
        The element-wise comparison of the column's ``.values`` with 0 (a
        boolean NumPy array for plain numeric columns).
    """
    column_values = df[sensitive_attr].values
    return column_values == 0


def load_hsls_raw(path="datasets/hsls_knn_impute.pkl"):
    """Load the raw (pre-imputed) HSLS data from a pickle file.

    Args:
        path: Location of the pickle file. Defaults to the originally
            hard-coded ``"datasets/hsls_knn_impute.pkl"``. A relative path is
            resolved against the current working directory.

    Returns:
        The object unpickled by ``pd.read_pickle``, which is expected to be
        the raw HSLS ``DataFrame``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    return pd.read_pickle(path)


def preprocess_hsls(df):
    """Separate ``df`` into a feature frame and the ``gradebin`` label array.

    Args:
        df: DataFrame containing a ``gradebin`` column.

    Returns:
        ``(features, labels)``: ``features`` is a copy of ``df`` without the
        ``gradebin`` column, and ``labels`` is that column's ``.values``.
        ``df`` itself is not modified.
    """
    target_column = "gradebin"
    return df.drop(columns=target_column), df[target_column].values


def clean_hsls(df):
    """Clean the raw HSLS frame and derive the binary target/sensitive columns.

    Steps:
      1. Entries ``<= -7`` (out-of-range codes) are treated as missing.
      2. Every row that contains a missing value is dropped.
      3. ``gradebin``, ``racebin`` and ``sexbin`` are derived from
         ``grade9thbin``, ``studentrace`` and ``studentgender``. Those source
         columns and ``grade12thbin`` are then dropped.
      4. All remaining columns are min-max scaled.

    The input frame is not modified; a new DataFrame is returned.

    Args:
        df: Raw HSLS DataFrame. Assumed all-numeric, since the ``<= -7``
            comparison is applied to every column (confirm against the
            dataset).

    Returns:
        A new DataFrame with the scaled columns, indexed by the surviving rows.
    """
    # Out-of-range entries (values <= -7) become NaN. `mask` returns a new
    # frame, so the caller's DataFrame is left untouched.
    df = df.mask(df <= -7)

    # Drop every row with any missing value (rows only, not columns).
    # This step significantly reduces the number of samples.
    df = df.dropna()

    # Creating racebin & gradebin & sexbin variables
    # X1SEX: 1 -- Male, 2 -- Female, -9 -- NaN -> Preprocess it to: 0 -- Female, 1 -- Male, drop NaN
    # X1RACE: 0 -- BHN, 1 -- WA
    # `studentrace * 7` is truncated (not rounded) to an integer code, and rows
    # whose code is 1 or 7 get racebin == 1. `studentgender` is truncated the
    # same way. NOTE(review): presumably studentrace is a race code rescaled by
    # 1/7; truncation of KNN-imputed fractional values is kept as-is, confirm.
    race_code = (df["studentrace"] * 7).astype(int)
    # `assign` works on a fresh copy, avoiding chained-assignment warnings on
    # the frame returned by `dropna`.
    df = df.assign(
        gradebin=df["grade9thbin"],
        racebin=((race_code == 7) | (race_code == 1)).astype(int),
        sexbin=df["studentgender"].astype(int),
    )

    # Drop the raw race/gender columns and the 12th-grade outcome to focus on
    # the 9th-grade prediction.
    df = df.drop(
        columns=["studentgender", "grade9thbin", "grade12thbin", "studentrace"]
    )

    # NOTE(review): the scaler is fit on every row here, i.e. before any
    # train/test split done by the caller, so test rows influence the scaling.
    scaler = MinMaxScaler()
    return pd.DataFrame(scaler.fit_transform(df), columns=df.columns, index=df.index)
