import numpy as np
import pandas as pd
import category_encoders as ce
import pickle

# Load the header-less census file and turn the " ?" placeholder (note the
# leading space) into NaN.
path = './data_census_onehot/adult_trim.data'
data = pd.read_csv(path, header=None).replace(" ?", np.nan)

# Positions of the categorical columns in the frame as loaded.
cat_list = [1, 3, 5, 6, 7, 8, 9, 13, 14]

# Reorder the columns so that every non-categorical column comes first and the
# categorical ones follow, then relabel the columns 0..n-1 in that new order.
n_cols = data.shape[1]
temp_list = [col for col in range(n_cols) if col not in cat_list] + cat_list
new_cols_order = temp_list
reordered_labels = data.columns[new_cols_order]
data = data.reindex(columns=reordered_labels)
data.columns = list(range(n_cols))

# After the reorder the leading columns are the non-categorical ones and the
# trailing len(cat_list) columns are categorical; rebuild both index lists.
n_cat = len(cat_list)
cont_list = list(range(n_cols - n_cat))
cat_list = list(range(len(cont_list), n_cols))

# One-hot encode the categorical columns. The fitted encoder is kept because
# it is needed later to map one-hot rows back to category values.
categorical_labels = data.columns[cat_list]
encoder = ce.one_hot.OneHotEncoder(cols=categorical_labels)
trans_data = encoder.fit_transform(data)


# new_observed_values = np.nan_to_num(new_observed_values)

# Per-column minimum and maximum (NaNs ignored) of the leading col_num
# columns of the encoded frame, i.e. the non-categorical part.
col_num = len(cont_list)
encoded_values = trans_data.to_numpy()
data_cont = encoded_values[:, :col_num]
min_arr = np.nanmin(data_cont, axis=0)
max_arr = np.nanmax(data_cont, axis=0)



# Generated samples to be decoded back into a readable table.
new_data = np.load('./generated_samples_census2_0.2.npy')

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# point this at a file produced by a trusted step of this pipeline.
with open("./data_census_onehot/transformed_columns.pk", "rb") as pickle_file:
    cont_cols, saved_cat_dict = pickle.load(pickle_file)

# First element of every entry in saved_cat_dict, in the dict's iteration
# order (presumably the first column of each one-hot group; confirm against
# the script that wrote the pickle).
index = [saved_cat_dict[key][0] for key in saved_cat_dict]

# new_data[:,index[0]:] = np.round(new_data[:,index[0]:])
# new_data_cat = new_data[:,index[0]:]
# ind = new_data_cat == -1
# new_data_cat[ind] = 0
# new_data[:,index[0]:] = new_data_cat

import torch

# Add the total column count as a final boundary so that consecutive entries
# of `index` delimit every column span, including the last one.
index.append(new_data.shape[1])
new_data = torch.from_numpy(new_data).float()

# Replace each span with a hard one-hot row: 1 at the position of the largest
# value, 0 elsewhere. No softmax is applied first: softmax is monotonic, so it
# cannot change which position is largest, and its float32 rounding could only
# merge near-equal values into ties.
for start, stop in zip(index[:-1], index[1:]):
    winners = torch.argmax(new_data[:, start:stop], dim=1)
    one_hot_rows = torch.nn.functional.one_hot(winners, num_classes=stop - start)
    new_data[:, start:stop] = one_hot_rows.to(new_data.dtype)


# Map the one-hot rows back to category values with the fitted encoder and
# continue on the raw array.
decoded = encoder.inverse_transform(new_data.numpy()).to_numpy()

# Rescale the leading col_num columns with x * (max - min) + min and round
# them to whole numbers.
# NOTE(review): this assumes the generated samples hold these columns min-max
# scaled with the same min_arr/max_arr; confirm against the training code.
cont_block = (slice(None), slice(0, col_num))
decoded[cont_block] = decoded[cont_block] * (max_arr - min_arr) + min_arr
decoded[cont_block] = np.round(decoded[cont_block].astype(float))

excel_data = pd.DataFrame(decoded)

# Put the columns back in the order they had in the input file.
# `new_cols_order` is the permutation applied before encoding; argsort of a
# permutation is its inverse. Deriving it here (instead of hard-coding the
# list) keeps the two in sync if the set of categorical columns changes.
# For the current column layout this evaluates to
# [0, 6, 1, 7, 2, 8, 9, 10, 11, 12, 3, 4, 5, 13, 14].
re_temp_list = np.argsort(new_cols_order).tolist()

excel_data = excel_data.reindex(columns=excel_data.columns[re_temp_list])
excel_data.columns = list(range(excel_data.shape[1]))



# Header row for the exported file, one name per column in output order.
column_names = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
    'label',
]

# Write the table without the row index.
excel_data.to_csv(
    './generated_samples_census2_0.2.csv',
    index=False,
    header=column_names,
)


