from evaluation_census import compute_scores
import pandas as pd
import numpy as np

import ipdb

# Evaluate the utility (F1) of synthesized census samples against the real,
# trimmed Adult data, one generated-samples file per missing ratio, and write
# the per-ratio mean/std F1 to a CSV.

MISSING_RATIOS = [0.2]
REAL_DATA_PATH = './data_census_onehot/adult_trim.data'
SEED = 1
NFOLD = 5
# Earlier runs wrote to 'miss_diff_census_utility2.csv' and
# 'miss_diff_census_utility3_<Classifier>.csv' for DecisionTree,
# AdaBoostClassifier, LogisticRegression and MLPClassifier.
# NOTE(review): presumably this name must be edited by hand to match the
# classifier used inside compute_scores -- TODO confirm.
OUTPUT_CSV = 'miss_diff_census_utility3_RandomForestClassifier.csv'


def _load_synthesized(path):
    """Load a generated-samples CSV and encode the income label as 0/1.

    The file is read with ``header=None`` and row 0 is then dropped, so the
    first line of the file is discarded (presumably a header line written by
    the generator -- TODO confirm).
    """
    frame = pd.read_csv(path, header=None)
    frame.drop(0, axis=0, inplace=True)
    frame.replace(" <=50K", 0, inplace=True)
    frame.replace(" >50K", 1, inplace=True)
    return frame


def _load_real(path):
    """Load the real census data, mapping ' ?' to NaN and the label to 0/1."""
    frame = pd.read_csv(path, header=None)
    frame.replace(" ?", np.nan, inplace=True)
    frame.replace(" <=50K", 0, inplace=True)
    frame.replace(" >50K", 1, inplace=True)
    return frame


def _split_last_fold(n_rows, seed, nfold):
    """Return ``(train_index, test_index)`` for the last of ``nfold`` folds.

    Row indices are shuffled after seeding the *global* NumPy RNG with
    ``seed + 1``. The last fold of the shuffled order is the test set, and
    every remaining index (shuffled again) is the training set.
    """
    indices = np.arange(n_rows)
    np.random.seed(seed + 1)
    np.random.shuffle(indices)

    fold_ratio = 1 / nfold
    start = int((nfold - 1) * n_rows * fold_ratio)
    end = start + int(n_rows * fold_ratio)
    test_index = indices[start:end]
    train_index = np.delete(indices, np.arange(start, end))  # a copy
    np.random.shuffle(train_index)
    return train_index, test_index


def _continuous(name, low, high):
    """Metadata entry for a continuous column with the given value range."""
    return {"max": high, "min": low, "name": name, "type": "continuous"}


def _categorical(name, values):
    """Metadata entry for a categorical column; ``size`` is derived from ``values``."""
    return {"i2s": values, "size": len(values), "name": name, "type": "categorical"}


def _census_metadata():
    """Build the column metadata passed to ``compute_scores``.

    Category strings keep the leading space present in the raw data.
    ``np.nan`` entries stand for missing values (real-data ``' ?'``).
    """
    return {
        'problem_type': 'binary_classification',
        'columns': [
            _continuous(0, 17, 90),
            _categorical(1, [
                " State-gov", " Self-emp-not-inc", " Private", " Federal-gov",
                " Local-gov", " Self-emp-inc", " Without-pay", " Never-worked",
                np.nan,
            ]),
            _continuous(2, 12285, 1484705),
            _categorical(3, [
                ' Bachelors', ' HS-grad', ' 11th', ' Masters', ' 9th',
                ' Some-college', ' Assoc-acdm', ' Assoc-voc', ' 7th-8th',
                ' Doctorate', ' Prof-school', ' 5th-6th', ' 10th', ' 1st-4th',
                ' Preschool', ' 12th',
            ]),
            _continuous(4, 1, 16),
            # NOTE(review): ' Never-mwarried' looks like a typo of
            # ' Never-married'; kept verbatim in case the value really occurs
            # in the data or samples -- TODO confirm.
            _categorical(5, [
                ' Never-married', ' Married-civ-spouse', ' Divorced',
                ' Married-spouse-absent', ' Separated', ' Married-AF-spouse',
                ' Widowed', ' Never-mwarried',
            ]),
            _categorical(6, [
                ' Adm-clerical', ' Exec-managerial', ' Handlers-cleaners',
                ' Prof-specialty', ' Other-service', ' Sales', ' Craft-repair',
                ' Transport-moving', ' Farming-fishing', ' Machine-op-inspct',
                ' Tech-support', np.nan, ' Protective-serv', ' Armed-Forces',
                ' Priv-house-serv',
            ]),
            _categorical(7, [
                ' Not-in-family', ' Husband', ' Wife', ' Own-child',
                ' Unmarried', ' Other-relative',
            ]),
            _categorical(8, [
                ' White', ' Black', ' Asian-Pac-Islander',
                ' Amer-Indian-Eskimo', ' Other',
            ]),
            _categorical(9, [' Male', ' Female']),
            _continuous(10, 0, 99999),
            _continuous(11, 0, 4356),
            _continuous(12, 1, 99),
            _categorical(13, [
                ' United-States', ' Cuba', ' Jamaica', ' India', np.nan,
                ' Mexico', ' South', ' Puerto-Rico', ' Honduras', ' England',
                ' Canada', ' Germany', ' Iran', ' Philippines', ' Italy',
                ' Poland', ' Columbia', ' Cambodia', ' Thailand', ' Ecuador',
                ' Laos', ' Taiwan', ' Haiti', ' Portugal',
                ' Dominican-Republic', ' El-Salvador', ' France', ' Guatemala',
                ' China', ' Japan', ' Yugoslavia', ' Peru',
                ' Outlying-US(Guam-USVI-etc)', ' Scotland', ' Trinadad&Tobago',
                ' Greece', ' Nicaragua', ' Vietnam', ' Hong', ' Ireland',
                ' Hungary', ' Holand-Netherlands',
            ]),
            _categorical(14, [0, 1]),  # income label after the 0/1 encoding
        ],
    }


all_mean_f1 = []
all_std_f1 = []
evaluated_ratios = []

# The real data does not depend on the missing ratio, so load it only once.
data = _load_real(REAL_DATA_PATH)

for missing_ratio in MISSING_RATIOS:
    path = './generated_samples_census2_' + str(round(missing_ratio, 1)) + '.csv'
    synthesized_data = _load_synthesized(path)

    # The split reseeds the global NumPy RNG. Doing it once per iteration
    # keeps the split identical across ratios, and compute_scores always
    # starts from the same RNG state.
    train_index, test_index = _split_last_fold(len(data), SEED, NFOLD)
    train = data.iloc[train_index]
    test = data.iloc[test_index]

    # A fresh metadata dict per call, in case compute_scores mutates it.
    metadata = _census_metadata()
    mean_f1, std_f1 = compute_scores(
        train.to_numpy(), test.to_numpy(), synthesized_data.to_numpy(), metadata
    )
    evaluated_ratios.append(missing_ratio)
    all_mean_f1.append(mean_f1)
    all_std_f1.append(std_f1)

# One row per evaluated ratio. The ratio column records the ratios actually
# run, rather than a hard-coded value.
df = pd.DataFrame(
    {"missing_ratio": evaluated_ratios, "f1_mean": all_mean_f1, "f1_std": all_std_f1}
)
df.to_csv(OUTPUT_CSV)
print(all_mean_f1)
print(all_std_f1)