import numpy as np
import pandas as pd
import category_encoders as ce
import pickle
import ipdb

def reverse_transformation_bayesian(missing_rate=""):
    """Convert generated MIMIC4ED samples from model space back to a CSV table.

    Steps:
      1. Load ``./data_MIMIC4ED_onehot/used.csv``, fit a one-hot encoder on its
         first column, and take the per-column min/max of the remaining
         encoded columns.
      2. Load ``./generated_MIMIC4ED/generated_samples_MIMIC4ED2_<missing_rate>.npy``,
         replace its first two columns by a hard one-hot vector (argmax of the
         softmax) and round the columns listed in ``col_num``.
      3. Invert the one-hot encoding, rescale columns 1..71 with
         ``x * (max - min) + min``, round the columns listed in ``int_list``
         and write the table to the matching ``.csv`` file with named columns.

    Args:
        missing_rate: Suffix identifying the generated-sample file, e.g.
            ``"0.2"``. It is inserted verbatim into both the ``.npy`` input
            and the ``.csv`` output file names.

    Returns:
        None. The result is written to disk.
    """
    # Local import: torch is only needed here and is not imported at module level.
    import torch

    # How used.csv was produced (kept for provenance):
    # path = './data_MIMIC4ED_onehot/train.csv'
    # data = pd.read_csv(path)
    # data = data.iloc[:, [7, 18, 26, 27, 28, 32, 33, 34, 35, 36, 37, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 78, 79, 80, 82, 83, 85, 89, 94, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 20, 115, 116]]
    # data['ed_los'] = pd.to_timedelta(data['ed_los']).dt.seconds / 60
    # data.to_csv('./data_MIMIC4ED_onehot/used.csv')
    data = pd.read_csv('./data_MIMIC4ED_onehot/used.csv')
    # 'Unnamed: 0' is presumably the index column written by the to_csv call above.
    data = data.drop(columns='Unnamed: 0')

    # Only the first column is categorical; it is one-hot encoded.
    cat_list = [0]
    encoder = ce.one_hot.OneHotEncoder(cols=data.columns[cat_list])
    encoder.fit(data)
    trans_data = encoder.transform(data)

    # Encoded layout: the first two columns are the one-hot block, the rest
    # are treated as continuous. The width (74) is hard-coded and must match
    # the generated array.
    discrete_index = [0, 1]
    n_encoded_cols = 74
    # Built as an ordered list (not via set difference) so that the order of
    # max_arr / min_arr is guaranteed to follow the column order.
    continuous_index = [i for i in range(n_encoded_cols) if i not in discrete_index]

    data_cont = trans_data.to_numpy()[:, continuous_index]
    max_arr = np.nanmax(data_cont, axis=0)
    min_arr = np.nanmin(data_cont, axis=0)

    sample_stem = f'./generated_MIMIC4ED/generated_samples_MIMIC4ED2_{missing_rate}'
    new_data = torch.from_numpy(np.load(sample_stem + '.npy')).float()

    # Harden the one-hot block: pick the most probable class per row.
    n_classes = len(discrete_index)
    probs = torch.nn.functional.softmax(new_data[:, :n_classes], 1)
    winners = torch.argmax(probs, dim=1)
    # The buffer width is the fixed class count. Sizing it from winners.max() + 1
    # yields an (n, 1) buffer when every row picks class 0, which then
    # broadcasts and sets BOTH one-hot columns to 1 (and fails on empty input).
    one_hot = torch.zeros(winners.shape[0], n_classes)
    one_hot[torch.arange(winners.shape[0]), winners] = 1
    new_data[:, :n_classes] = one_hot

    # Columns (encoded-space indices) that are snapped to the nearest integer.
    # NOTE(review): this rounding happens before the min/max rescaling below,
    # i.e. presumably on normalized values — confirm this is intended for
    # indices 2..11.
    col_num = [2,3,4,5,6,7,8,9,10,11,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64]
    new_data[:, col_num] = new_data[:, col_num].round()

    # Back to the original (non one-hot) column layout.
    excel_data = encoder.inverse_transform(new_data.numpy()).to_numpy()

    # The last column was left out of normalization, so it is not rescaled
    # here: only columns 1..71 are mapped back with the first 71 min/max values.
    rescaled = slice(1, 72)
    span = max_arr[:-1] - min_arr[:-1]
    excel_data[:, rescaled] = excel_data[:, rescaled] * span + min_arr[:-1]

    # Columns (original-layout indices) rounded to whole numbers after rescaling.
    int_list = [12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,65,66,67,68,69,70,71,72]
    excel_data[:, int_list] = excel_data[:, int_list].astype('float').round()

    header = [
        'gender', 'age', 'n_ed_30d', 'n_ed_90d', 'n_ed_365d', 'n_hosp_30d',
        'n_hosp_90d', 'n_hosp_365d', 'n_icu_30d', 'n_icu_90d',
        'n_icu_365d', 'triage_temperature', 'triage_heartrate',
        'triage_resprate', 'triage_o2sat', 'triage_sbp', 'triage_dbp',
        'triage_pain', 'triage_acuity', 'chiefcom_chest_pain',
        'chiefcom_abdominal_pain', 'chiefcom_headache',
        'chiefcom_shortness_of_breath', 'chiefcom_back_pain',
        'chiefcom_cough', 'chiefcom_nausea_vomiting',
        'chiefcom_fever_chills', 'chiefcom_syncope', 'chiefcom_dizziness',
        'cci_MI', 'cci_CHF', 'cci_PVD', 'cci_Stroke', 'cci_Dementia',
        'cci_Pulmonary', 'cci_Rheumatic', 'cci_PUD', 'cci_Liver1',
        'cci_DM1', 'cci_DM2', 'cci_Paralysis', 'cci_Renal', 'cci_Cancer1',
        'cci_Liver2', 'cci_Cancer2', 'cci_HIV', 'eci_Arrhythmia',
        'eci_Valvular', 'eci_PHTN', 'eci_HTN1', 'eci_HTN2',
        'eci_NeuroOther', 'eci_Hypothyroid', 'eci_Lymphoma',
        'eci_Coagulopathy', 'eci_Obesity', 'eci_WeightLoss',
        'eci_FluidsLytes', 'eci_BloodLoss', 'eci_Anemia', 'eci_Alcohol',
        'eci_Drugs', 'eci_Psychoses', 'eci_Depression',
        'ed_temperature_last', 'ed_heartrate_last', 'ed_resprate_last',
        'ed_o2sat_last', 'ed_sbp_last', 'ed_dbp_last', 'ed_los', 'n_med',
        'n_medrecon',
    ]
    pd.DataFrame(excel_data).to_csv(sample_stem + '.csv', index=False, header=header)

if __name__ == "__main__":
    # Script entry point: post-process the generated samples whose file
    # names carry the "0.2" suffix.
    reverse_transformation_bayesian("0.2")
