"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments
Preprocess data: clean, one-hot encode, standardize, and transform data
"""


from struct import unpack
import numpy as np
import pandas as pd
import random
from pandas.core.indexes import category
import utils as ut
from pandas.tseries.offsets import BDay
from sklearn.preprocessing import OneHotEncoder
from tqdm import tqdm

def _has_value(series):
  """Return a boolean mask that is True where `series` is neither null nor ''."""
  return ~(pd.isnull(series) | (series == ''))


def _as_day_array(dates):
  """Format each date with strftime and return the result as a datetime64[D] array."""
  return np.array([d.strftime('%Y-%m-%d') for d in dates], dtype = 'datetime64[D]')


def _encode_drugs(data, parms, drug_pct_cutoff):
  """
    Keep single-drug rows for the most frequent drugs and one-hot encode 'drugs' / 'prior_drugs'.
    Assumes both columns hold list-like values of drug-name strings.
    Stores the retained drug names in parms['all_drugs'].
    Returns:
    data: filtered frame with 'drugs_<name>' and 'prior_<name>' columns replacing the two list columns
  """
  from collections import Counter

  # every drug seen in either column, collected before any row is removed
  all_drugs_set = set()
  for column in ('drugs', 'prior_drugs'):
    for entry in data[column]:
      all_drugs_set.update(entry)

  original_len = data.shape[0]

  # remove rows that have drug combinations (anything but exactly one drug and one prior drug)
  single_drug = (data['drugs'].apply(len) == 1) & (data['prior_drugs'].apply(len) == 1)
  data = data.loc[single_drug, :].copy()

  # count the appearances of each drug in the 'drugs' column of the remaining rows
  print("counting drugs")
  first_drug = data['drugs'].apply(lambda x: x[0])
  first_prior = data['prior_drugs'].apply(lambda x: x[0])
  tally = Counter(first_drug)

  # rank by count (ties broken by name so the result is reproducible) and
  # keep the drugs whose cumulative share stays strictly below the cutoff
  ranked = sorted(all_drugs_set, key = lambda d: (-tally[d], d))
  counts = np.array([tally[d] for d in ranked], dtype = float)
  total = counts.sum()
  drugs_inclusion_list = []
  if total > 0:
    cumulative_pct = counts.cumsum() / total
    drugs_inclusion_list = [d for d, pct in zip(ranked, cumulative_pct) if pct < drug_pct_cutoff]

  # a row survives only if both its drug and its prior drug are retained
  keep = first_drug.isin(drugs_inclusion_list) & first_prior.isin(drugs_inclusion_list)
  data = data.loc[keep, :]
  new_len = data.shape[0]
  print("keep {:.1f}% of rows due to drug selections".format(100 * new_len / (original_len + 0.001)))
  data = data.reset_index(drop=True)

  # one indicator column per retained drug, all 'drugs_' columns first, then all 'prior_' columns
  print("encoding drugs")
  chosen = {
    'drugs_': first_drug[keep].reset_index(drop=True),
    'prior_': first_prior[keep].reset_index(drop=True),
  }
  for prefix in ('drugs_', 'prior_'):
    for drug in drugs_inclusion_list:
      data[prefix + drug] = (chosen[prefix] == drug).astype(int)
  data = data.drop(columns = ['drugs', 'prior_drugs'])

  # save all the drugs included for inverse transformation
  parms['all_drugs'] = drugs_inclusion_list
  print("Finished drugs")
  return data


def _encode_prior_morbs(data, parms):
  """
    Turn the keys of the 'morbs_prior' dictionary strings into 0/1 indicator columns.
    Stores the indicator column names in parms['all_morbs_prior'].
    Returns:
    data: frame with one column per key and without 'morbs_prior'
  """
  # NOTE(security): eval() executes whatever expression the cell contains; this is only
  # safe for trusted input. ast.literal_eval may be a drop-in replacement if the cells
  # are plain dict literals -- TODO confirm against the real data.
  morb_keys = data['morbs_prior'].apply(lambda x: set(eval(x).keys()))

  print("finding all set ...")
  all_morbs_set = set()
  for keys in tqdm(morb_keys):
    all_morbs_set |= keys
  print("number of prior morbs: {:3d}".format(len(all_morbs_set)))
  # sorted so that the column order does not depend on set iteration order
  all_morbs_list = sorted(all_morbs_set, key = str)

  print("encoding prior morbs ...")
  for col in tqdm(all_morbs_list):
    data[col] = morb_keys.apply(lambda keys: 1 if col in keys else 0)

  parms['all_morbs_prior'] = all_morbs_list
  data = data.drop(columns = ['morbs_prior'])
  print("Finished morbs_prior")
  return data


def _encode_safety_morbs(data, parms):
  """
    Expand the 'safety_morbs' dictionary strings into one column per key.
    The key list is taken from the first row that has a value; rows without a value get 0.
    Stores the key list in parms['all_safety_morbs'].
    Returns:
    data: frame with one column per key and without 'safety_morbs'
  """
  has_value = _has_value(data['safety_morbs']).to_numpy()
  # NOTE(security): eval() on data cells, see the remark in _encode_prior_morbs.
  # Each cell is parsed once instead of once per output column.
  records = [eval(text) if present else None for text, present in zip(data['safety_morbs'], has_value)]
  first_record = next((rec for rec in records if rec is not None), None)
  all_safety_morbs_list = list(first_record.keys()) if first_record is not None else []

  print("encoding safety morbs")
  for col_name in tqdm(all_safety_morbs_list):
    data[col_name] = [rec[col_name] if rec is not None else 0 for rec in records]
  parms['all_safety_morbs'] = all_safety_morbs_list

  data = data.drop(columns = ['safety_morbs'])
  print("Finished safety_prior")
  return data


def _encode_dates(data, parms, nan_cols, not_nan_cutoff = 10):
  """
    Convert every date column into a day count and impute missing date / lab values with the column mean.
    'date-' becomes business days since the reference date; 'date+' and the '*_dt' lab dates become
    business days since 'date-'; 'last_refill' becomes calendar days since 'date-'.
    Lab date columns with fewer than not_nan_cutoff usable values are dropped together with their
    '*_lab' column, and both names are appended to nan_cols (mutated in place).
    Stores 'lab_date_ind_cols', 'date_reference', 'replace_values' and 'all_date_cols' in parms.
    Returns:
    data: frame with the converted columns stored under their original names
  """
  # keep the latest date in last_refill column, assuming 'last_refill' is a list of dates
  data['last_refill'] = data['last_refill'].apply(max)

  # choose lab date cols that have enough values, drop the others and their related lab columns
  lab_date_cols = []
  for col in [c for c in data.columns if c[-3:] == '_dt']:
    usable = _has_value(data['date-']) & _has_value(data[col])
    if usable.sum() < not_nan_cutoff:
      related_lab_col = col[:-3] + '_lab'
      nan_cols.extend([col, related_lab_col])
      data = data.drop(columns = [col, related_lab_col])
    else:
      lab_date_cols.append(col)

  # indicator columns: 1 where the lab date is present, 0 where it is null
  lab_date_indicators = []
  for col in lab_date_cols:
    ind_col_name = col[:-3] + '_indicator'
    data[ind_col_name] = data[col].notnull().astype(int)
    lab_date_indicators.append(ind_col_name)
  parms['lab_date_ind_cols'] = lab_date_indicators

  # convert all _dt columns and 'date+' to business days since 'date-' (holidays and weekends excluded),
  # convert 'last_refill' to calendar days since 'date-', then drop the original date columns
  date_cols = ['date+'] + ['last_refill'] + lab_date_cols
  all_holidays = ut.holidays()
  print("converting all date columns to days ...")
  for col in tqdm(date_cols):
    usable = _has_value(data['date-']) & _has_value(data[col])
    diff_col = col + '_diff'
    data[diff_col] = np.nan
    if col == 'last_refill':
      data.loc[usable, diff_col] = (data[col][usable] - data['date-'][usable]).apply(lambda x: x.days)
    else:
      data.loc[usable, diff_col] = np.busday_count(_as_day_array(data['date-'][usable]),
                                                   _as_day_array(data[col][usable]),
                                                   holidays = all_holidays)
    data = data.drop(columns = col)

  # convert 'date-' to business days since the reference date, then drop the original column
  date_reference = '2014-01-01'
  usable = _has_value(data['date-'])
  data['date-_diff'] = np.nan
  data.loc[usable, 'date-_diff'] = np.busday_count(date_reference,
                                                   _as_day_array(data['date-'][usable]),
                                                   holidays = all_holidays)
  data = data.drop(columns = 'date-')

  # give the day-count columns the names of the date columns they replace
  all_date_cols = ['date-'] + date_cols
  data = data.rename(columns = {x + '_diff': x for x in all_date_cols})

  # replace all nan values with the mean in all date and lab related columns
  all_lab_cols = [x[:-3] + '_lab' for x in lab_date_cols]
  replace_date_values = data[all_date_cols].mean().astype(int)
  replace_lab_values = data[all_lab_cols].mean()

  parms['date_reference'] = date_reference
  parms['replace_values'] = pd.concat([replace_date_values, replace_lab_values], axis = 0)
  parms['all_date_cols'] = all_date_cols
  for x in all_date_cols + all_lab_cols:
    # assign the result back: fillna(inplace=True) on a selected column is not
    # guaranteed to write through to the frame
    data[x] = data[x].fillna(parms['replace_values'][x])

  print("Finished dates")
  return data


def _encode_categoricals(data, parms):
  """
    One-hot encode 'race_cd', 'ethncty_cd' and 'gndr_cd'.
    Missing values, 'NA' and '~01' are mapped to 'UNK' first; the category dropped for each
    column comes from ut.encoding_categories().
    Stores the fitted encoder and the encoded column names in parms.
    Returns:
    data: frame with '<column>_<category>' columns replacing the three original columns
  """
  original_categorical_cols = ['race_cd', 'ethncty_cd', 'gndr_cd']
  for x in original_categorical_cols:
    # assign the result back: inplace calls on `.loc[:, x]` act on a temporary
    data[x] = data[x].fillna('UNK').replace(['NA', '~01'], 'UNK')

  # get categories to drop, to avoid colinearity
  drop_categories = ut.encoding_categories()

  # scikit-learn renamed `sparse` to `sparse_output` in 1.2 and later removed the old name
  try:
    enc = OneHotEncoder(drop = drop_categories, handle_unknown = 'error', sparse_output = False)
  except TypeError:
    enc = OneHotEncoder(drop = drop_categories, handle_unknown = 'error', sparse = False)
  encoded_data = enc.fit_transform(data[original_categorical_cols])

  # categories that actually received an output column, per original column
  kept_categories = [np.delete(levels, dropped) for levels, dropped in zip(enc.categories_, enc.drop_idx_)]

  all_encoded_cols_flat = []
  for base_name, levels in zip(original_categorical_cols, kept_categories):
    all_encoded_cols_flat += [base_name + "_" + x for x in levels]

  for position, name in enumerate(all_encoded_cols_flat):
    data[name] = encoded_data[:, position]

  parms['original_categorical_cols'] = original_categorical_cols
  parms['leveled_encoded_cols'] = kept_categories
  parms['flat_encoded_cols'] = all_encoded_cols_flat
  parms['encoder'] = enc
  data = data.drop(columns = original_categorical_cols)
  print("Finished race_cd")
  return data


def preprocess(raw_data, drug_pct_cutoff = 0.7):
  """
    Pre-process, transform, clean up, and encode data.
    Each encoding step only runs when its trigger column is present:
    'drugs' + 'prior_drugs', 'morbs_prior', 'safety_morbs', 'date-', 'race_cd'.
    raw_data is not modified; all work happens on a copy.
    Args:
    raw_data: pandas dataframe with the raw annotation table
    drug_pct_cutoff: drugs are kept, most frequent first, while their cumulative share
                     of the single-drug rows stays strictly below this value
    Returns:
    orig_data: cleaned and transformed data in pandas dataframe
    parms: parameters that can be used to restore data
  """

  orig_data = raw_data.copy(deep = True)

  parms = {}
  nan_cols = []

  orig_data.replace('', np.nan, inplace = True)

  ################################################
  # encode drugs and prior_drugs to different columns, assuming 'drugs' and 'prior_drugs' are lists
  if 'drugs' in orig_data.columns and 'prior_drugs' in orig_data.columns:
    orig_data = _encode_drugs(orig_data, parms, drug_pct_cutoff)

  ################################################
  # encode morbs_prior keys to different columns, assuming 'morbs_prior' are dictionary strings
  if 'morbs_prior' in orig_data.columns:
    orig_data = _encode_prior_morbs(orig_data, parms)

  ################################################
  # encode safety_morbs to different columns, assuming 'safety_morbs' are dictionary strings
  if 'safety_morbs' in orig_data.columns:
    orig_data = _encode_safety_morbs(orig_data, parms)

  ################################################
  # process all date related columns
  NOT_NAN_CUTOFF = 10
  if 'date-' in orig_data.columns:
    orig_data = _encode_dates(orig_data, parms, nan_cols, NOT_NAN_CUTOFF)

  ################################################
  # one hot encode columns 'race_cd', 'ethncty_cd', 'gndr_cd'
  if 'race_cd' in orig_data.columns:
    orig_data = _encode_categoricals(orig_data, parms)

  # other clean ups in entire dataframe: remove the identifier columns when present,
  # treat 'UNK', '' and '-' as missing, then drop every row that still has a missing value
  orig_data.drop(columns = ['trajectory_index', 'mcid'], inplace = True, errors = 'ignore')
  orig_data.replace('UNK', np.nan, inplace = True)
  orig_data.replace('', np.nan, inplace = True)
  orig_data.replace('-', np.nan, inplace = True)
  orig_data.dropna(axis=0, how='any', inplace = True)
  orig_data = orig_data.reset_index(drop=True)

  if 'p_male' in orig_data.columns:
    orig_data.drop(columns = ['p_male'], inplace = True)

  # format some columns into proper precision; the column is replaced (not written
  # through .loc) so that it takes the numeric dtype of the converted values
  if 'zip_cd' in orig_data.columns:
    orig_data['zip_cd'] = orig_data['zip_cd'].apply(lambda x: int(x))
  if 'p_female' in orig_data.columns:
    orig_data['p_female'] = orig_data['p_female'].apply(lambda x: float(x))
  if 'median_age' in orig_data.columns:
    orig_data['median_age'] = orig_data['median_age'].apply(lambda x: float(x))
  if 'median_income' in orig_data.columns:
    orig_data['median_income'] = orig_data['median_income'].apply(lambda x: float(x.replace('+','').replace(',','')))
  if 'mean_income' in orig_data.columns:
    orig_data['mean_income'] = orig_data['mean_income'].apply(lambda x: float(x))

  # convert remaining string columns into floats; a column counts as a string
  # column when its first value is a str (nothing to inspect in an empty frame)
  str_cols = []
  if len(orig_data) > 0:
    for col in orig_data.columns:
      if isinstance(orig_data[col].iloc[0], str):
        str_cols.append(col)
        orig_data[col] = orig_data[col].apply(lambda x: float(x))

  parms['str_cols'] = str_cols
  parms['nan_cols'] = nan_cols
  return orig_data, parms

def restore(proc_data, parms):
  """
    Inverse transformation, restore original annotation table format.
    Every step only runs when the matching entry is present in parms, so a frame that was
    preprocessed without drug, morbidity, date or categorical columns can be restored too.
    proc_data is not modified; all work happens on a copy.
    New 'mcid' values are drawn as a random permutation of 1000000 .. 1000000 + n_rows - 1.
    Args:
    proc_data: processed (or synthetic) pandas dataframe in the encoded layout
    parms: parameter dictionary produced during preprocessing
    Returns:
    df: restored annotation table
  """

  rest_data = proc_data.copy(deep = True)

  # pack morbs_prior into one column (dict string of the entries equal to 1), then drop the encoded columns
  if 'all_morbs_prior' in parms:
    morb_cols = list(parms['all_morbs_prior'])
    morb_rows = rest_data[morb_cols].to_numpy().tolist()
    rest_data['morbs_prior'] = [str({name: value for name, value in zip(morb_cols, row) if value == 1})
                                for row in morb_rows]
    rest_data.drop(columns = morb_cols, inplace = True)

  # pack safety_morbs into one column (dict string without spaces), then drop all encoded columns
  if 'all_safety_morbs' in parms:
    safety_cols = list(parms['all_safety_morbs'])
    safety_rows = rest_data[safety_cols].to_numpy().tolist()
    rest_data['safety_morbs'] = [str(dict(zip(safety_cols, row))).replace(' ', '') for row in safety_rows]
    rest_data.drop(columns = safety_cols, inplace = True)

  # pack drugs and prior drugs into one list column each, then drop all encoded columns
  if 'all_drugs' in parms:
    drug_names = list(parms['all_drugs'])
    for name in ['drugs', 'prior']:
      encoded_cols = [name + '_' + x for x in drug_names]
      col_name = name if name == 'drugs' else name + '_drugs'
      flags = (rest_data[encoded_cols] == 1).to_numpy()
      decoded = [[drug for drug, is_set in zip(drug_names, row) if is_set] for row in flags]
      rest_data[col_name] = pd.Series(decoded, index = rest_data.index, dtype = object)
      rest_data.drop(columns = encoded_cols, inplace = True)

  # convert all date columns from number of days to actual dates
  if 'date_reference' in parms:
    all_holidays = ut.holidays()

    # restore date-; np.busday_offset needs integer offsets, and roll='forward' keeps it
    # from raising when the start date itself is a weekend or holiday
    has_value = (rest_data['date-'] != '')
    day_offsets = rest_data['date-'][has_value].astype(float).round().astype('int64').to_numpy()
    rest_data.loc[has_value, 'date-'] = np.busday_offset(parms['date_reference'], day_offsets,
                                                         roll = 'forward', holidays = all_holidays)

    for col in parms['all_date_cols']:
      if col == 'date-':
        continue
      has_value = ~((rest_data['date-'] == '') | (rest_data[col] == ''))
      # 'last_refill' and other date columns are dealt with differently because of holidays and weekends,
      # assuming last_refill can happen on any day
      if col == 'last_refill':
        rest_data.loc[has_value, col] = rest_data['date-'][has_value] + rest_data[col][has_value].apply(lambda x: pd.Timedelta(x, unit = 'd'))
        rest_data.loc[has_value, col] = rest_data[col][has_value].apply(lambda x: [x])
      else:
        start_days = np.array([d.strftime('%Y-%m-%d') for d in rest_data['date-'][has_value]], dtype = 'datetime64[D]')
        day_offsets = rest_data[col][has_value].astype(float).round().astype('int64').to_numpy()
        rest_data.loc[has_value, col] = np.busday_offset(start_days, day_offsets,
                                                         roll = 'forward', holidays = all_holidays)
      if (~has_value).any():
        # the column must be able to hold strings before '' is written into it
        rest_data[col] = rest_data[col].astype(object)
        rest_data.loc[~has_value, col] = ''

  # restore lab nan values based on the indicator column
  if 'lab_date_ind_cols' in parms:
    for col in parms['lab_date_ind_cols']:
      lab_dt_col = col.replace('_indicator', '_dt')
      lab_col = col.replace('_indicator', '_lab')
      no_values = (rest_data[col] == 0)
      if no_values.any():
        for target in (lab_dt_col, lab_col):
          if target in rest_data.columns:
            # the column must be able to hold strings before '' is written into it
            rest_data[target] = rest_data[target].astype(object)
          rest_data.loc[no_values, target] = ''
      rest_data.drop(columns = [col], inplace = True)

  # inverse transform using encoder
  if 'encoder' in parms:
    temp = parms['encoder'].inverse_transform(rest_data[parms['flat_encoded_cols']])
    for position, restored_col in enumerate(parms['original_categorical_cols']):
      rest_data[restored_col] = temp[:, position]
    rest_data.drop(columns = parms['flat_encoded_cols'], inplace = True)

  # add back the male column as the complement of the female percentage
  if 'p_female' in rest_data.columns:
    rest_data['p_male'] = 100 - rest_data['p_female']

  # add back all the nan columns as empty strings
  for nan_col in parms.get('nan_cols', []):
    rest_data[nan_col] = ''

  # add back the mcid and trajectory_index columns
  mcid_list = list(range(rest_data.shape[0]))
  random.shuffle(mcid_list)
  rest_data['mcid'] = [1000000 + x for x in mcid_list]
  if 'date-' in rest_data.columns:
    rest_data['trajectory_index'] = rest_data[['mcid', 'date-']].apply(lambda x: str(x['mcid']) + str(x['date-']).replace(' ', ''), axis = 1)
  else:
    # no date information was preprocessed; the index is the mcid alone
    rest_data['trajectory_index'] = [str(x) for x in rest_data['mcid']]

  # properly format some columns; each column is replaced (not written through .loc)
  # because the formatted values are strings while the existing column is numeric
  if 'zip_cd' in rest_data.columns:
    rest_data['zip_cd'] = rest_data['zip_cd'].apply(lambda x: '{:05d}'.format(int(x)))
  if 'p_male' in rest_data.columns:
    rest_data['p_male'] = rest_data['p_male'].apply(lambda x: '{:.2f}'.format(x))
  if 'p_female' in rest_data.columns:
    rest_data['p_female'] = rest_data['p_female'].apply(lambda x: '{:.2f}'.format(x))
  if 'median_age' in rest_data.columns:
    rest_data['median_age'] = rest_data['median_age'].apply(lambda x: '{:.1f}'.format(x))
  if 'median_income' in rest_data.columns:
    rest_data['median_income'] = rest_data['median_income'].apply(lambda x: '{:.0f}'.format(x))
  if 'mean_income' in rest_data.columns:
    rest_data['mean_income'] = rest_data['mean_income'].apply(lambda x: '{:.0f}'.format(x))

  return rest_data

def enforce_one(df, parms):
  """
    Pass every group of one-hot columns through ut.enforce_to_one and write the result back into df.
    Per the original description, a group with more than one 1 in a row has one of them picked
    at random and the others set to zero (that logic lives in ut.enforce_to_one, not shown here).
    Groups handled: one per encoded categorical column (when 'encoder' is in parms) and the
    'drugs_*' and 'prior_*' columns (when 'all_drugs' is in parms).
    df is modified in place and also returned.
    Returns:
    df: processed dataframe
  """

  # one group per original categorical column: '<column>_<level>' for each kept level
  print('enforcing to one for encoder')
  if 'encoder' in parms:
    for position, levels in enumerate(parms['leveled_encoded_cols']):
      base_name = parms['original_categorical_cols'][position]
      level_cols = [base_name + "_" + level for level in levels]
      df.loc[:, level_cols] = ut.enforce_to_one(df[level_cols])

  # two groups for the drugs: current drugs and prior drugs
  print('enforcing to one for all_drugs')
  if 'all_drugs' in parms:
    for prefix in ('drugs', 'prior'):
      prefixed_cols = [prefix + '_' + drug for drug in parms['all_drugs']]
      df.loc[:, prefixed_cols] = ut.enforce_to_one(df[prefixed_cols])

  return df



