from argparse import Namespace
from pathlib import Path
from fairpate_tabular.utils import process_data
import pandas as pd
import numpy as np

# Default column layout applied to the reformatted Adult frames when the
# caller does not supply an explicit `column_order`.
DPDG_ADULT_COLUMN_ORDER = [
    'age',
    'work-class',
    'fnlwgt',
    'education',
    'education-num',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'capital-gain',
    'capital-loss',
    'hours-per-week',
    'native-country',
    'income',
]

def get_train_test_index(seed, dataset, split, path, data_path, cols_to_norm, sensitive_attributes, undersampling_ratio, output_col_name):
    """Return the (train_index, test_index) pair produced by `process_data`.

    All arguments are bundled into an `argparse.Namespace` and handed to
    `process_data` together with a NumPy random generator seeded with `seed`.
    How each setting is interpreted is decided by `process_data`; this
    wrapper only forwards them.
    """
    settings = {
        "dataset": dataset,
        "output_col_name": output_col_name,
        "split": split,
        "seed": seed,
        "path": path,
        "data_path": data_path,
        "undersampling_ratio": undersampling_ratio,
        "cols_to_norm": cols_to_norm,
        "sensitive_attributes": sensitive_attributes,
    }
    args = Namespace(**settings)

    # Seed the generator from the namespace, the same object process_data sees.
    generator = np.random.default_rng(args.seed)

    train_index, test_index = process_data(
        rng=generator,
        args=args,
        log=print,
        return_train_test_index=True,
    )
    return train_index, test_index

def write_dp_dg_format_train_test_data(seed, dataset, split, path, data_path, save_path, cols_to_norm, sensitive_attributes, undersampling_ratio, output_col_name, column_order=None):
    """Write the train/test split of the CSV at `path` as two CSV files.

    The split indices come from `get_train_test_index`; the rows are then
    selected positionally from the CSV at `path` and written to
    `<save_path>/train.csv` and `<save_path>/test.csv` (without the index
    column). `save_path` is created if it does not exist.

    When `dataset == 'adult'` each split is additionally reformatted:
      * an `income` column is derived from `>50K` ("Yes" -> ">50K",
        anything else -> "<=50K"),
      * `workclass` is exposed as `work-class`,
      * the original `>50K` and `workclass` columns are dropped,
      * columns are selected/ordered by `column_order`, falling back to
        `DPDG_ADULT_COLUMN_ORDER` when `column_order` is None.
    For any other dataset the rows are written unchanged and `column_order`
    is not applied.

    Args:
        seed, dataset, split, path, data_path, cols_to_norm,
        sensitive_attributes, undersampling_ratio, output_col_name:
            forwarded unchanged to `get_train_test_index`.
        path: also the CSV file the rows are read from.
        save_path: output directory; a `str` or any path-like object.
        column_order: optional list of column names for the Adult output.
    """
    train_index, test_index = get_train_test_index(seed, dataset, split, path, data_path, cols_to_norm, sensitive_attributes, undersampling_ratio, output_col_name)

    # Parse the CSV once; both splits are positional selections of one frame.
    full_df = pd.read_csv(path)
    train_df = full_df.iloc[train_index]
    test_df = full_df.iloc[test_index]

    def _reformat(_df):
        """Return a reformatted copy of an Adult split; `_df` is not modified."""
        income = np.where(_df[">50K"].eq("Yes"), ">50K", "<=50K")
        # `assign` builds a new frame, so no column is set in place on the
        # `.iloc` selection (in-place assignment there can trigger pandas'
        # SettingWithCopyWarning).
        reformatted = _df.assign(**{"income": income, "work-class": _df["workclass"]})
        reformatted = reformatted.drop(columns=[">50K", "workclass"])
        order = DPDG_ADULT_COLUMN_ORDER if column_order is None else column_order
        return reformatted[order]

    if dataset == 'adult':
        train_df = _reformat(train_df)
        test_df = _reformat(test_df)

    # Join with pathlib so a `Path` save_path works as well as a `str`.
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    train_df.to_csv(out_dir / "train.csv", index=False)
    test_df.to_csv(out_dir / "test.csv", index=False)

if __name__ == "__main__":
    # Export the Adult train/test split using the settings below.
    adult_cols_to_norm = [
        'age',
        'fnlwgt',
        'education-num',
        'capital-gain',
        'capital-loss',
        'hours-per-week',
    ]

    write_dp_dg_format_train_test_data(
        # Dataset identity and label column.
        dataset='adult',
        output_col_name='>50K',
        # Split settings.
        seed=0,
        split=0.75,
        undersampling_ratio=None,
        # Input and output locations.
        path='./Datasets/Adult/adult_original_purified.csv',
        data_path='./fairpate_tabular/data/',
        save_path='./baselines/dp_dg/data/adult_v1.0',
        # Column handling.
        cols_to_norm=adult_cols_to_norm,
        sensitive_attributes=['sex'],
        # Same names, in the same order, as the module-level default layout.
        column_order=list(DPDG_ADULT_COLUMN_ORDER),
    )
    