import numpy as np
import pandas as pd
from folktables import ACSDataSource, ACSEmployment, ACSIncome
from sklearn.model_selection import train_test_split


def convert_to_here(features, label, group, test_size, random_state):
    """Restrict a dataset to its two largest groups, standardize, and split.

    Keeps only the rows whose ``group`` value is one of the two most frequent
    group ids. It then z-scores every feature column over those kept rows,
    encodes the sensitive attribute as a boolean, and casts labels to float.
    Finally it splits everything with ``sklearn.model_selection.train_test_split``.

    Args:
        features: 2-D array-like of feature rows, one row per sample.
        label: 1-D array-like of labels aligned with ``features``.
        group: 1-D array-like of group ids aligned with ``features``.
        test_size: Forwarded unchanged to ``train_test_split``.
        random_state: Forwarded unchanged to ``train_test_split``.

    Returns:
        ``(train_set, test_set)``. Each is a tuple
        ``(features_df, sensitive, labels)``, where:

        - ``features_df`` is a ``pandas.DataFrame`` of standardized features.
        - ``sensitive`` is a boolean array that is True for rows of the
          second-largest group.
        - ``labels`` is a float array.

    Raises:
        ValueError: If ``group`` contains fewer than two distinct values.
    """
    group = np.asarray(group)
    group_ids, group_sizes = np.unique(group, return_counts=True)
    if group_ids.size < 2:
        raise ValueError(
            f"need at least two distinct groups, got {group_ids.size}"
        )

    # Order groups largest-first; a stable sort makes ties resolve
    # deterministically (to the smaller group id).
    order = np.argsort(-group_sizes, kind="stable")
    majority_id = group_ids[order[0]]
    minority_id = group_ids[order[1]]

    rows_to_keep = (group == majority_id) | (group == minority_id)

    features = np.asarray(features)[rows_to_keep]
    # NOTE(review): mean/std are computed over all kept rows *before* the
    # train/test split, so test rows influence the scaling of training rows.
    features = features - np.mean(features, axis=0, keepdims=True)
    std = np.std(features, axis=0, keepdims=True)
    # A constant column has std 0; dividing would yield NaN. Divide by 1
    # instead so such columns stay at 0 after centring.
    std = np.where(std == 0, 1.0, std)
    features = features / std

    features = pd.DataFrame(data=features)
    # Sensitive attribute: True for the second-largest group.
    sensitive = group[rows_to_keep] == minority_id
    label = np.asarray(label)[rows_to_keep].astype(float)

    (
        features_train,
        features_test,
        sensitive_train,
        sensitive_test,
        labels_train,
        labels_test,
    ) = train_test_split(
        features, sensitive, label, test_size=test_size, random_state=random_state
    )
    train_set = (features_train, sensitive_train, labels_train)
    test_set = (features_test, sensitive_test, labels_test)
    return train_set, test_set


def get_employment(test_size=0.2, random_state=0, states=("AL",), survey_year="2018"):
    """Load the folktables ACSEmployment task and split it with ``convert_to_here``.

    Reads the 1-Year ACS person survey through ``ACSDataSource``. It passes
    ``download=True`` so folktables may fetch the files. The data is turned
    into arrays with ``ACSEmployment.df_to_numpy`` and handed to
    ``convert_to_here``.

    Args:
        test_size: Forwarded to ``convert_to_here``.
        random_state: Forwarded to ``convert_to_here``.
        states: State abbreviation(s) to load. A single string is treated as
            one state. Defaults to Alabama only, the previously hard-coded
            choice.
        survey_year: ACS survey year, as a string.

    Returns:
        ``(train_set, test_set)`` exactly as returned by ``convert_to_here``.
    """
    # Guard against list("CA") -> ["C", "A"] when a bare string is passed.
    if isinstance(states, str):
        states = [states]
    data_source = ACSDataSource(
        survey_year=survey_year, horizon="1-Year", survey="person"
    )
    acs_data = data_source.get_data(states=list(states), download=True)
    features, label, group = ACSEmployment.df_to_numpy(acs_data)

    return convert_to_here(features, label, group, test_size, random_state)


def get_income(test_size=0.2, random_state=0, states=("AL",), survey_year="2018"):
    """Load the folktables ACSIncome task and split it with ``convert_to_here``.

    Reads the 1-Year ACS person survey through ``ACSDataSource``. It passes
    ``download=True`` so folktables may fetch the files. The data is turned
    into arrays with ``ACSIncome.df_to_numpy`` and handed to
    ``convert_to_here``.

    Args:
        test_size: Forwarded to ``convert_to_here``.
        random_state: Forwarded to ``convert_to_here``.
        states: State abbreviation(s) to load. A single string is treated as
            one state. Defaults to Alabama only, the previously hard-coded
            choice.
        survey_year: ACS survey year, as a string.

    Returns:
        ``(train_set, test_set)`` exactly as returned by ``convert_to_here``.
    """
    # Guard against list("CA") -> ["C", "A"] when a bare string is passed.
    if isinstance(states, str):
        states = [states]
    data_source = ACSDataSource(
        survey_year=survey_year, horizon="1-Year", survey="person"
    )
    acs_data = data_source.get_data(states=list(states), download=True)
    features, label, group = ACSIncome.df_to_numpy(acs_data)

    return convert_to_here(features, label, group, test_size, random_state)
