from folktables import ACSIncome, ACSDataSource

from .base import Data
from utils import DATA_DIR


class NewAdult(Data):
    """Income-prediction dataset built from the folktables ``ACSIncome`` task.

    Raw records come from the 2018 one-year ACS person survey for California,
    fetched through folktables into ``DATA_DIR``. Cleaning, encoding and
    normalisation are delegated to ``Data._preprocess_df`` using the column
    configuration assembled in :meth:`setup`.
    """

    name = "newadult"
    # NOTE(review): these dimensions presumably must match the widths produced
    # by _preprocess_df for the CA/2018 slice configured below — confirm.
    feat_dim = 24
    sens_dim = 13
    simple_sens_cols = [0, 1]

    def __init__(self, **kwargs):
        # No dataset-specific state; forward every keyword to the base class.
        super().__init__(**kwargs)

    def setup(self, stage: str):
        """Download (if needed) the CA ACS data and run the base preprocessor.

        Args:
            stage: Unused by this implementation.
        """
        source = ACSDataSource(
            survey_year='2018',
            horizon='1-Year',
            survey='person',
            root_dir=DATA_DIR,
        )
        raw = source.get_data(states=["CA"], download=True)
        features, labels, _group = ACSIncome.df_to_pandas(raw)

        # Column glossary (from "Retiring Adult: New Datasets for Fair Machine Learning"):
        #   AGEP  - age of person              COW   - class of worker
        #   SCHL  - educational attainment     MAR   - marital status
        #   OCCP  - occupation                 POBP  - place of birth
        #   RELP  - relationship               WKHP  - usual hours worked per week
        #                                              over the past 12 months
        #   SEX   - 1 = male, 2 = female       RAC1P - recorded detailed race code
        #   PINCP - label: total person's income > 50,000
        frame = features.join(labels)

        preprocess_options = {
            "sens_columns": ["SEX", "AGEP", "MAR", "RAC1P"],
            "label_column": ["PINCP"],
            "drop_columns": ["POBP", "RELP"],
            "categorical_values": ["COW", "MAR", "SEX", "RAC1P"],
            # Fold RAC1P codes 3, 4, 5 and 7 into code 8. Presumably this merges
            # small race groups into "Some Other Race alone" (per the ACS PUMS
            # data dictionary); confirm.
            "mapping": {"RAC1P": {3: 8, 4: 8, 5: 8, 7: 8}},
            "mapping_label": {},
            "normalise_columns": ["AGEP", "SCHL", "OCCP", "WKHP"],
            "drop_rows": {},
        }
        super()._preprocess_df(frame, **preprocess_options)
