import os

import pandas as pd
from sklearn.model_selection import train_test_split


def main(data_dir='./Data_Sets/Diabetes/', sample_frac=0.15, test_size=0.20):
    """Build stratified train/test CSV files from the raw diabetes data.

    Dataset downloaded from:
    https://www.hindawi.com/journals/bmri/2014/781670/#supplementary-materials

    Pipeline:
      1. Read ``diabetic_data_initial.csv`` from ``data_dir``.
      2. Drop columns that are not used as features.
      3. Ordinal-encode the drug columns, ``age``, ``diabetesMed`` and the
         ``readmitted`` label (``'NO'`` -> 1, ``'>30'``/``'<30'`` -> 0).
         Any value missing from an encoding map becomes NaN.
      4. Treat ``'?'`` as missing and drop every row with a missing value.
      5. Randomly subsample ``sample_frac`` of the remaining rows.
      6. Drop rows whose feature columns duplicate an earlier row.
      7. Split into train/test sets stratified on ``readmitted`` and write
         ``diabetes-train.csv`` and ``diabetes-test.csv`` into ``data_dir``.

    Args:
        data_dir: Directory containing the raw CSV. The split files are
            written to the same directory.
        sample_frac: Fraction of the cleaned rows to keep (default 0.15).
        test_size: Fraction of the subsample assigned to the test set
            (default 0.20).
    """
    df = pd.read_csv(os.path.join(data_dir, 'diabetic_data_initial.csv'),
                     index_col=False, skipinitialspace=True)

    # Columns that are not used as model features.
    df = df.drop(
        columns=['admission_type_id', 'discharge_disposition_id', 'admission_source_id', 'encounter_id', 'patient_nbr',
                 'race', 'gender', 'weight', 'diag_1', 'diag_2', 'diag_3',
                 'medical_specialty', 'payer_code', 'change', 'A1Cresult', 'max_glu_serum', 'number_outpatient',
                 'number_emergency', 'number_inpatient'])

    # Drug columns flagged as low variance. They are kept and encoded together
    # with the other drug columns rather than dropped.
    low_variance_cols = ['acetohexamide', 'citoglipton', 'nateglinide', 'troglitazone', 'glimepiride-pioglitazone',
                         'metformin-rosiglitazone', 'metformin-pioglitazone', 'tolbutamide', 'tolazamide', 'examide']
    drug_cols = ['metformin', 'repaglinide', 'chlorpropamide', 'glimepiride',
                 'insulin', 'glyburide-metformin', 'glipizide-metformin',
                 'glipizide', 'glyburide', 'pioglitazone', 'rosiglitazone', 'acarbose', 'miglitol'] + low_variance_cols

    # Ordinal encodings. Series.map turns any value that is absent from the
    # mapping into NaN, so those rows are removed by the dropna() below.
    dose_map = {'No': 0, 'Steady': 1, 'Down': 2, 'Up': 3}
    for col in drug_cols:
        df[col] = df[col].map(dose_map)

    age_map = {"[0-10)": 0, "[10-20)": 1, "[20-30)": 2, "[30-40)": 3, "[40-50)": 4,
               "[50-60)": 5, "[60-70)": 6, "[70-80)": 7, "[80-90)": 8, "[90-100)": 9, "[100-110)": 10}
    dia_map = {'No': 0, 'Yes': 1}
    # Binary label: 'NO' -> 1; both '>30' and '<30' collapse to 0.
    class_map = {'NO': 1, '>30': 0, '<30': 0}
    df['age'] = df['age'].map(age_map)
    df['diabetesMed'] = df['diabetesMed'].map(dia_map)
    df['readmitted'] = df['readmitted'].map(class_map)

    # Treat '?' as missing: convert it to NaN, then drop every row that
    # contains any NaN.
    df = df[df != '?'].dropna(axis=0)

    # Subsample a fraction of the cleaned rows (15 percent by default).
    df = df.sample(frac=sample_frac, replace=False, random_state=1)

    # Drop rows whose features repeat an earlier row, keeping the first
    # occurrence regardless of its label.
    feature_cols = [col for col in df.columns if col != 'readmitted']
    df = df.drop_duplicates(subset=feature_cols)

    # Each output frame keeps the label column alongside the features.
    train_df, test_df = train_test_split(df, stratify=df['readmitted'], test_size=test_size, random_state=0)
    train_df.to_csv(os.path.join(data_dir, 'diabetes-train.csv'), index=False)
    test_df.to_csv(os.path.join(data_dir, 'diabetes-test.csv'), index=False)


if __name__ == "__main__":
    # Run the preprocessing only when executed as a script, not on import.
    main()
