from pathlib import Path

import argparse
import numpy as np
import pandas as pd
import sklearn.preprocessing as preprocessing

def load_uci(train_file, test_file):
    """Read the raw UCI Adult train and test files into DataFrames.

    Both files are parsed without a header row, using a fixed list of 15
    column names. Fields are split on commas with any surrounding whitespace
    removed, and a bare ``?`` field is read as missing (NaN).

    Args:
        train_file: Path (or anything ``pd.read_csv`` accepts) of the
            training file.
        test_file: Path of the test file. Its first line is skipped
            (presumably a non-data banner line in ``adult.test`` -- confirm
            against the downloaded file).

    Returns:
        A ``(train, test)`` tuple of DataFrames with identical columns.
    """
    # https://archive.ics.uci.edu/ml/machine-learning-databases/adult/

    column_names = [
        "age", "workclass", "final_weight", "education", "education_label",
        "martial_status", "job", "relationship", "race", "sex",
        "capital_gain", "capital_loss", "hours_per_week", "country", "target",
    ]

    # Options shared by both reads; a regex separator needs the python engine.
    read_options = dict(
        names=column_names,
        sep=r'\s*,\s*',
        engine='python',
        na_values="?",
    )

    train = pd.read_csv(train_file, **read_options)
    test = pd.read_csv(test_file, skiprows=1, **read_options)
    return train, test

def transform_uci_features(df):
    """One-hot encode ``df`` and standardize the resulting columns.

    Args:
        df: DataFrame of raw features; non-numeric columns are expanded by
            ``pd.get_dummies``.

    Returns:
        A new DataFrame of the scaled values. It is rebuilt from the scaler's
        output array, so it carries a fresh 0..n-1 index rather than the
        index of ``df``.

    NOTE(review): the last two columns of the one-hot encoded frame are
    discarded before scaling. Which columns those are depends on the column
    order of ``df`` -- confirm this is still intended for the current callers.
    """
    encoded = pd.get_dummies(df)
    kept_names = encoded.columns[:-2]
    kept = encoded[kept_names]

    # The scaler is fitted on exactly the frame it transforms.
    standardized = preprocessing.StandardScaler().fit_transform(kept)
    return pd.DataFrame(standardized, columns=kept.columns)

def process_uci(train_df, test_df, split_train_test=False):
    """Clean the UCI Adult frames and build features, labels and sensitive attribute.

    The two frames are concatenated, ``Wife`` and ``Husband`` in the
    ``relationship`` column are merged into ``spouse``, and every row with a
    missing value is dropped. ``target`` is mapped to 0 for ``<=50K`` and 1
    for ``>50K`` (with or without a trailing period). The ``education`` and
    ``target`` columns are removed before the remaining columns are passed to
    ``transform_uci_features``; the ``sex`` dummy columns (and, for the
    training / combined features, ``country_Holand-Netherlands``) are removed
    from the transformed features afterwards.

    Args:
        train_df: Training DataFrame as returned by ``load_uci``.
        test_df: Test DataFrame as returned by ``load_uci``.
        split_train_test: When true, return separate train and test outputs,
            each row assigned to the frame it originally came from. When
            false (default), return a single combined output.

    Returns:
        If ``split_train_test`` is false: ``(features, labels, sens_attrib)``
        for all cleaned rows.
        If true: ``((train, test), (train_labels, test_labels),
        (sens_attrib_train, sens_attrib_test))``.
        Feature frames have a fresh 0..n-1 index, while the label and
        sensitive-attribute Series keep the index of the cleaned concatenated
        frame; rows correspond by position.

    Raises:
        KeyError: If an expected dummy column (``sex_Male``, ``sex_Female``,
            ``country_Holand-Netherlands``) is absent from the transformed
            features.
    """
    n_train_rows = train_df.shape[0]

    combined = pd.concat([train_df, test_df]).reset_index(drop=True)
    combined['relationship'] = combined['relationship'].replace(
        {'Wife': 'spouse', 'Husband': 'spouse'})
    combined.dropna(inplace=True)

    # After reset_index each label is the row's position in the
    # concatenation, and dropna keeps labels, so labels >= n_train_rows mark
    # exactly the rows that came from test_df. Slicing by the pre-dropna
    # length of test_df (as was done before) both used a stale row count and
    # cut from the wrong end of the train-then-test concatenation.
    from_test = combined.index >= n_train_rows

    labels = combined['target'].replace(
        {'<=50K': 0, '>50K': 1, '<=50K.': 0, '>50K.': 1})
    sens_attrib = combined.sex

    combined.drop(['education', 'target'], axis=1, inplace=True)

    if split_train_test:
        train = transform_uci_features(combined[~from_test])
        train.drop(['sex_Male', 'sex_Female', 'country_Holand-Netherlands'],
                   axis=1, inplace=True)
        test = transform_uci_features(combined[from_test])
        test.drop(['sex_Male', 'sex_Female'], axis=1, inplace=True)
        return (
            (train, test),
            (labels[~from_test], labels[from_test]),
            (sens_attrib[~from_test], sens_attrib[from_test]),
        )

    train = transform_uci_features(combined)
    train.drop(['sex_Male', 'sex_Female', 'country_Holand-Netherlands'],
               axis=1, inplace=True)
    return train, labels, sens_attrib


if __name__ == "__main__":
    # Script entry point: read adult.data / adult.test from --data-dir,
    # preprocess them, and write the feature and metadata CSVs back into the
    # same directory.
    parser = argparse.ArgumentParser()
    # Required: without it args.data_dir is None and Path(None) fails with
    # an opaque TypeError instead of a usage message.
    parser.add_argument("--data-dir", type=str, required=True,
                        help="directory containing adult.data and adult.test; "
                             "output CSVs are written here as well")
    args = parser.parse_args()
    data_dir = Path(args.data_dir)
    train_df, test_df = load_uci(train_file=data_dir / 'adult.data',
                                 test_file=data_dir / 'adult.test')
    uci_df, income_target, gender = process_uci(train_df, test_df)

    # Labels and the sensitive attribute share one index, so a column-wise
    # concat lines them up row for row.
    uci_metadata_df = pd.concat([income_target, gender], axis=1)
    uci_metadata_df.columns = ['income', 'type']

    data_path = data_dir / 'uci_df.csv'
    print(f'Writing data to {data_path}')
    uci_df.to_csv(data_path, index=False)
    metadata_path = data_dir / 'uci_metadata_df.csv'
    print(f'Writing metadata to {metadata_path}')
    uci_metadata_df.to_csv(metadata_path, index=False)
