"""Data loading functions for the ADSGAN framework.

-----------------------------
data_loader.py
(1) Load data and return it as a pandas DataFrame
"""

# Import necessary packages
import pandas as pd
import numpy as np

def load_maggic_data(file_name='data/Maggic.csv'):
  """Load MAGGIC data.

  Reads the MAGGIC CSV and drops every row that has a missing value in any
  column. It then removes the label columns 'death_all' and 'days_to_fu',
  leaving only features.

  Args:
    file_name: Path to the MAGGIC CSV file. Defaults to 'data/Maggic.csv'
      (resolved relative to the current working directory).

  Returns:
    orig_data: Original data in pandas dataframe (features only). The index
      is the one produced by read_csv, so it has gaps where rows were dropped.

  Raises:
    KeyError: If 'death_all' or 'days_to_fu' is not a column of the file.
  """
  # Read csv file
  orig_data = pd.read_csv(file_name, sep=',')

  # Remove rows with any NA. This runs before the labels are dropped, so a
  # row whose only missing value is a label is discarded too.
  orig_data = orig_data.dropna(axis=0, how='any')

  # Remove labels
  orig_data = orig_data.drop(['death_all', 'days_to_fu'], axis=1)

  return orig_data


def load_random_data(num_rows=128 * 5, seed=None):
  """Generate and return random data.

  Args:
    num_rows: Number of rows to generate. Defaults to 640 (128 * 5).
    seed: Optional integer seed. When None (the default), NumPy's global
      random state is used. Otherwise a dedicated np.random.RandomState(seed)
      is used, which makes the output reproducible and leaves the global
      random state untouched.

  Returns:
    orig_data: Original data in pandas dataframe with `num_rows` rows and 10
      integer columns 'A'..'J'. Each value is drawn uniformly from [0, 100).
  """
  # np.random and RandomState expose the same randint API, so the default
  # path draws from the global stream.
  rng = np.random if seed is None else np.random.RandomState(seed)

  orig_data = pd.DataFrame(rng.randint(0, 100, size=(num_rows, 10)),
                           columns=list('ABCDEFGHIJ'))

  return orig_data


def load_company_data(input_file_path='data/annotation_bp_20210301_sdoh.parquet',
                      columns=None):
  """Load company data.

  Args:
    input_file_path: Path to the parquet file to read. Defaults to
      'data/annotation_bp_20210301_sdoh.parquet'. Another path referenced
      for this data is './data/annotation_bp_20210301_sdoh_small.parquet'
      (presumably a smaller sample; confirm before relying on it).
    columns: Optional list of column names to keep, in the order given.
      When None (the default), every column in the file is returned.

  Returns:
    orig_data: Original data in pandas dataframe, re-indexed 0..n-1.

  Raises:
    KeyError: If any name in `columns` is not a column of the data.
  """
  # Read parquet file
  orig_data = pd.read_parquet(input_file_path)
  # Discard whatever index was stored in the file; callers get a 0..n-1 index.
  orig_data = orig_data.reset_index(drop=True)

  # All columns in the default file:
  # 'trajectory_index', 'mcid', 'date-', 'drugs', 'last_refill', 'lab-', 'lab+', 'date+', 'prior_drugs',
  # 'gndr_cd', 'race_cd', 'ethncty_cd', 'zip_cd', 'age', 'diastolic_lab', 'diastolic_dt', 'egfr_lab',
  # 'egfr_dt', 'creatinine_lab', 'creatinine_dt', 'potassium_lab', 'potassium_dt', 'glucose_lab',
  # 'glucose_dt', 'urea_nitrogen_lab', 'urea_nitrogen_dt', 'urea_creatinine_ratio_lab',
  # 'urea_creatinine_ratio_dt', 'sodium_lab', 'sodium_dt', 'chloride_lab', 'chloride_dt',
  # 'carbon_dioxide_lab', 'carbon_dioxide_dt', 'calcium_lab', 'calcium_dt', 'phosphorus_lab',
  # 'phosphorus_dt', 'albumin_lab', 'albumin_dt', 'bicarbonate_lab', 'bicarbonate_dt',
  # 'anion_gap_lab', 'anion_gap_dt', 'safety_morbs', 'morbs_prior', 'total_pop', 'p_male',
  # 'p_female', 'median_age', 'p_white', 'p_black_or_aframr', 'p_amrind_and_alskntv', 'p_asian',
  # 'p_ntvhwi_and_othrpcfisl', 'p_othr', 'p_multiracial', 'household_count', 'median_income',
  # 'mean_income'
  if columns is not None:
    orig_data = orig_data[list(columns)]

  return orig_data