"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments
Utility functions
"""



from matplotlib.pyplot import sca
import numpy as np
import pandas as pd
import random 
from tqdm import tqdm
from scipy.linalg import orth

def moving_average(x, w):
  """Return the simple moving average of ``x`` over windows of length ``w``.

  Only complete windows are used (``mode='valid'``), so for ``w <= len(x)``
  the result has ``len(x) - w + 1`` elements.
  """
  window_sums = np.convolve(x, np.ones(w), mode='valid')
  return window_sums / w

def holidays():
  """Return holiday dates from 2015 through 2020 as 'YYYY-MM-DD' strings.

  Seven holidays per year, in chronological order: New Year's Day, Martin
  Luther King Jr. Day, Memorial Day, Independence Day, Labor Day,
  Thanksgiving and Christmas. When a holiday falls on a weekend, the list
  gives the observed weekday date instead (e.g. '2015-07-03', '2016-12-26',
  '2017-01-02'). As a result, every entry is a Monday-Friday date.

  A new list is built on every call, so callers may mutate the result.
  """
  all_holidays = [
      '2015-01-01', '2015-01-19', '2015-05-25', '2015-07-03', '2015-09-07', '2015-11-26', '2015-12-25',
      '2016-01-01', '2016-01-18', '2016-05-30', '2016-07-04', '2016-09-05', '2016-11-24', '2016-12-26',
      '2017-01-02', '2017-01-16', '2017-05-29', '2017-07-04', '2017-09-04', '2017-11-23', '2017-12-25',
      '2018-01-01', '2018-01-15', '2018-05-28', '2018-07-04', '2018-09-03', '2018-11-22', '2018-12-25',
      # Labor Day 2019 is Monday 2019-09-02 (previously listed as '2019-09-07', a Saturday).
      '2019-01-01', '2019-01-21', '2019-05-27', '2019-07-04', '2019-09-02', '2019-11-28', '2019-12-25',
      '2020-01-01', '2020-01-20', '2020-05-25', '2020-07-03', '2020-09-07', '2020-11-26', '2020-12-25',
  ]
  return all_holidays

def col_to_round(in_data):
  """Return the names of the columns whose values need to be rounded.

  The result is built in three parts, in this order:

  1. A fixed list of known column names.
  2. Every column of ``in_data`` whose name ends in '_dt', in column order.
  3. 'date-'.

  The fixed names are returned whether or not they exist in ``in_data``.

  Args:
    in_data: DataFrame whose (string) column names are scanned for the
      '_dt' suffix.

  Returns:
    list of column-name strings.
  """
  fixed_cols = ['zip_cd', 'age', 'urea_nitrogen_lab', 'urea_creatinine_ratio_lab',
                'sodium_lab', 'chloride_lab', 'carbon_dioxide_lab', 'total_pop',
                'household_count', 'median_income', 'mean_income', 'date+',
                'last_refill']
  date_cols = [name for name in in_data.columns if name[-3:] == '_dt']
  return fixed_cols + date_cols + ['date-']

def compare_df(df1, df2):
  """Interleave the columns of two DataFrames for side-by-side comparison.

  The layout follows ``df1``'s columns, in order. For each column ``x``:

  - ``'df1_<x>'`` holds ``df1[x]``.
  - ``'df2_<x>'`` immediately follows it and holds ``df2[x]``.

  Rows are aligned on ``df1``'s index. Rows of ``df2`` whose label is not in
  ``df1.index`` are dropped, and labels that ``df2`` lacks become NaN.

  Args:
    df1: reference DataFrame; its columns and index define the layout.
    df2: DataFrame containing every column of ``df1`` (otherwise KeyError).

  Returns:
    A new DataFrame; an empty one if ``df1`` has no columns.
  """
  if len(df1.columns) == 0:
    return pd.DataFrame()
  paired = {}
  for x in df1.columns:
    # str() also accepts non-string labels (e.g. integer column names).
    paired['df1_' + str(x)] = df1[x]
    paired['df2_' + str(x)] = df2[x]
  # Build the frame in one shot. Inserting columns one by one fragments the
  # frame (pandas PerformanceWarning on wide inputs) and is slow.
  return pd.DataFrame(paired, index=df1.index)

def enforce_to_one(df_in):
  """Return a copy of ``df_in`` in which no row has more than one active column.

  Rows whose values sum to less than 1.5 are left untouched. In every other
  row:

  - One of the columns whose value exceeds 0.9 is picked uniformly at random.
    The pick uses ``random.shuffle``, so results follow the ``random`` module's
    seed.
  - The picked column is set to 1 and every other column in that row is set
    to 0.

  Rows are addressed by position, so any index is supported, not only a
  default 0..n-1 ``RangeIndex``.

  Args:
    df_in: DataFrame of numeric columns (presumably one-hot encoded
      categories; TODO confirm against callers). It is not modified.

  Returns:
    The adjusted copy.

  Raises:
    IndexError: if a row sums to >= 1.5 but no single value exceeds 0.9.
  """
  df = df_in.copy(deep = True)
  print('enforcing to one ...')
  for pos in tqdm(range(len(df))):
    row = df.iloc[pos]
    if row.sum() < 1.5:
      continue

    one_cols = df.columns[row > 0.9].values
    random.shuffle(one_cols)
    col_to_keep = one_cols[0]

    # Positional writes. The old df.loc[i, ...] treated the position i as a
    # row label, which raised KeyError on indexes other than 0..n-1.
    df.iloc[pos, :] = 0
    df.iloc[pos, df.columns.get_loc(col_to_keep)] = 1
  return df

def encoding_categories():
  """Return the categories to drop in the encoder.

  The result is ['UNK', 'UNK', 'F'], as a fresh list on every call.
  Presumably there is one entry per categorical column, in the encoder's
  column order; verify against the caller.
  """
  drop = ['UNK'] * 2
  drop.append('F')
  return drop
 
def data_normalization(orig_data_arr, epsilon = 1e-8, normalization_params = None):
  """Min-max scale ``orig_data_arr`` along axis 0 to the [0, 1] range.

  Args:
    orig_data_arr: numeric array-like (numpy array or DataFrame assumed;
      TODO confirm). Statistics are taken along axis 0, i.e. per column for
      2-D input.
    epsilon: added to the column range before dividing, so constant columns
      do not cause division by zero.
    normalization_params: optional dict with 'min_val' and 'max_val' from a
      previous call. When given, those statistics are reused and the result
      is clipped to [0, 1]. When omitted, they are computed from
      ``orig_data_arr``.

  Returns:
    A ``(normalized_data, normalization_params)`` tuple. ``max_val`` is the
    column range (max minus min), so ``data_renormalization`` with the same
    ``epsilon`` maps the result back to the original scale.
  """
  # Use `is None` rather than `== None`: identity cannot be overridden,
  # whereas `==` is elementwise for array-like objects and makes `if` ambiguous.
  if normalization_params is None:
    # Fit: compute per-column statistics from this data.
    min_val = np.min(orig_data_arr, axis=0)
    normalized_data = orig_data_arr - min_val
    # After shifting by the minimum, the maximum equals the column range.
    max_val = np.max(normalized_data, axis=0)
    normalized_data = normalized_data / (max_val + epsilon)
    normalization_params = {"min_val": min_val, "max_val": max_val}
  else:
    # Transform: reuse previously fitted statistics. New data may fall
    # outside the fitted range, so clip into [0, 1].
    normalized_data = orig_data_arr - normalization_params['min_val']
    normalized_data = normalized_data / (normalization_params['max_val'] + epsilon)
    normalized_data[normalized_data < 0.0] = 0.0
    normalized_data[normalized_data > 1.0] = 1.0

  return normalized_data, normalization_params

def data_renormalization(normalized_data, normalization_params, epsilon = 1e-8):
  """Restore data scaled by min-max normalization to its original scale.

  Computes ``normalized_data * (max_val + epsilon) + min_val`` using the
  'min_val' and 'max_val' entries of ``normalization_params``. Pass the same
  ``epsilon`` used for normalization. The inverse holds only up to
  floating-point rounding, and values that were clipped are not recovered.
  """
  span = normalization_params['max_val'] + epsilon
  offset = normalization_params['min_val']
  return normalized_data * span + offset

def sample_Z(m, n):
  """Draw an (m, n) array from the uniform distribution on [-1, 1)."""
  shape = (m, n)
  return np.random.uniform(low=-1.0, high=1.0, size=shape)

def sample_R(m, n):
  """Draw an (m, n) array from the uniform distribution on [0, 1)."""
  shape = (m, n)
  return np.random.uniform(low=0.0, high=1.0, size=shape)

def sample_eps(m):
  """Draw an (m, 1) column of samples from the uniform distribution on [0, 1)."""
  column_shape = (m, 1)
  return np.random.uniform(low=0.0, high=1.0, size=column_shape)

def sample_X(m, n):
  """Return up to ``n`` distinct random indices from ``range(m)``.

  The indices come from a random permutation of 0..m-1, so there are no
  repeats. If ``n >= m``, all ``m`` indices are returned in shuffled order.
  They are presumably used to pick a batch of rows from the real data.
  """
  shuffled = np.random.permutation(m)
  return shuffled[:n]

def xavier_init(size, do_orth = False):
  """Return random initial weights of dtype float32.

  Entries are drawn from N(0, s^2) with s = 1 / sqrt(size[0] / 2), i.e.
  sqrt(2 / fan_in). This matches He/Kaiming fan-in scaling despite the
  function's name.

  Note: with ``do_orth`` the output shape may differ from ``size``. The
  sample is replaced by ``scipy.linalg.orth`` of it (an orthonormal basis of
  its column space) times s. That result has shape (size[0], rank), which is
  narrower than size[1] when size[0] < size[1].
  """
  scale = 1. / np.sqrt(size[0] / 2.)
  weights = np.random.normal(size=size, scale=scale)
  if do_orth:
    weights = scale * orth(weights)
  return weights.astype(np.float32)

def select_x_columns(columms):
  """Return ``columms`` minus 'zip_cd' and any names starting with 'prior_'.

  Args:
    columms: pandas Index (or another container that supports boolean-list
      indexing) of string column names. Order is preserved.
  """
  keep = [not (name[:6] == 'prior_' or name == 'zip_cd') for name in columms]
  return columms[keep]

def select_covar_columns(columms):
  """Return ``columms`` minus 'zip_cd' and names starting with 'prior_' or 'drugs_'.

  Args:
    columms: pandas Index (or another container that supports boolean-list
      indexing) of string column names. Order is preserved.
  """
  dropped_prefixes = ('prior_', 'drugs_')
  keep = [name[:6] not in dropped_prefixes and name != 'zip_cd' for name in columms]
  return columms[keep]