"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments

Reference: Jinsung Yoon, Lydia N. Drumright, Mihaela van der Schaar, 
"Anonymization through Data Synthesis using Generative Adversarial Networks (ADS-GAN):
A harmonizing advancement for AI in medicine," 
IEEE Journal of Biomedical and Health Informatics (JBHI), 2019.
Paper link: https://ieeexplore.ieee.org/document/9034117
Last updated date: December 22nd, 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)
-----------------------------
main_adsgan.py
- Main function for Adapted ADSGAN framework
(1) Load data
(2) Preprocess
(3) Generate synthetic data
(4) Measure the quality and identifiability of generated synthetic data
(5) Plot loss during training
"""

#%% Import necessary packages
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

#%% Import necessary functions
from data_loader import load_company_data
from adsgan import adsgan
from metrics.feature_distribution import feature_distribution
from metrics.compute_wd import compute_wd
from metrics.compute_identifiability import compute_identifiability
from preprocess import preprocess, restore, enforce_one
import utils as ut
import pandas as pd
import pickle

# Feature switches read by exp_main(); edit them here to change what a run does.

# True: run preprocess() + enforce_one() on the raw data and write the encoded
# data and its parameters to data/*.pickle. False: load those pickles instead.
PRE_PROCESS        = False
# True: call compute_wd() on the encoded raw/synthetic data and print the result.
CAL_WD             = False
# True: call compute_identifiability() on the encoded raw/synthetic data and
# print the result.
CAL_ID             = False
# True: plot 30-point moving averages of the two series returned by adsgan().
PLOT_TRAINING_FIG  = False
# True: pass the synthetic data through enforce_one() before saving it.
ENF_SYNTH_DATA_ONE = False

#%% Experiment main function
def exp_main(args):
  """Run the adapted ADS-GAN experiment end to end.

  Steps: load the raw data, obtain its encoded form (freshly preprocessed or
  read back from the cached pickles, depending on PRE_PROCESS), call adsgan()
  to generate synthetic data, save the result under ``data/``, then optionally
  compute the Wasserstein distance and identifiability measures and plot the
  per-iteration series returned by adsgan().

  Args:
    args: parsed command-line namespace providing ``lamda``, ``iterations``,
      ``h_dim``, ``z_dim`` and ``mb_size``.

  Returns:
    None. Outputs are written to pickle files and printed to stdout.
  """

  # Data loading (done unconditionally, even when the cached encoding is used)
  raw_data = load_company_data()
  print("Finish data loading")

  # Encode the raw data, or reuse the encoding cached by an earlier run.
  if PRE_PROCESS:
    raw_data_encoded, data_parms = preprocess(raw_data, 0.7)
    raw_data_encoded = enforce_one(raw_data_encoded, data_parms)
    raw_data_encoded.to_pickle('data/raw_data_encoded.pickle')
    # Context manager guarantees the file is flushed and closed.
    with open("data/data_parms.pickle", 'wb') as parms_file:
      pickle.dump(data_parms, parms_file)
    print("Finish data encoding")
  else:
    # NOTE: unpickling executes arbitrary code; only load pickles that were
    # produced by this pipeline.
    raw_data_encoded = pd.read_pickle('data/raw_data_encoded.pickle')
    with open('data/data_parms.pickle', 'rb') as parms_file:
      data_parms = pickle.load(parms_file)

  # Generate synthetic data
  params = {
      "lamda": args.lamda,
      "iterations": args.iterations,
      "h_dim": args.h_dim,
      "z_dim": args.z_dim,
      "mb_size": args.mb_size,
  }

  synth_data_encoded, wd_whole_set, id_whole_set = adsgan(raw_data_encoded, params)
  if ENF_SYNTH_DATA_ONE:
    synth_data_encoded = enforce_one(synth_data_encoded, data_parms)

  # The output name embeds the hyper-parameters so different runs do not
  # overwrite each other.
  result_file = (
      'data/syn_encoded_iter_{}_hdim_{}_zdim_{}_mbsize_{}_lamda_{}.pickle'
  ).format(args.iterations, args.h_dim, args.z_dim, args.mb_size, args.lamda)
  synth_data_encoded.to_pickle(result_file)

  ## Performance measures
  # (1) Wasserstein Distance (WD)
  if CAL_WD:
    print("Start computing Wasserstein Distance")
    wd_measure = compute_wd(raw_data_encoded, synth_data_encoded, params)
    print("WD measure: " + str(wd_measure))

  # (2) Identifiability
  if CAL_ID:
    identifiability = compute_identifiability(raw_data_encoded, synth_data_encoded)
    print("Identifiability measure: " + str(identifiability))

  # Smooth the per-iteration series with a 30-point moving average and plot.
  if PLOT_TRAINING_FIG:
    # Imported lazily so matplotlib is only required when plotting is enabled.
    import matplotlib.pyplot as plt
    mv_wd_whole_set = ut.moving_average(wd_whole_set, 30)
    mv_id_whole_set = ut.moving_average(id_whole_set, 30)
    plt.figure()
    plt.plot(mv_wd_whole_set, label='dist loss')
    plt.plot(mv_id_whole_set, label='identifiablilty')
    plt.legend()
    plt.show()
  
#%%  
if __name__ == '__main__':

  # Command-line options, one row per flag: (flag, help text, default, type).
  # Alternative values kept for reference: iterations=7003, lamda=1.0.
  arg_specs = (
      ('--iterations', 'number of adsgan training iterations', 300, int),
      ('--h_dim', 'number of hidden state dimensions', 100, int),
      ('--z_dim', 'number of random state dimensions', 100, int),
      ('--mb_size', 'number of mini-batch samples', 256, int),
      ('--lamda',
       'hyper-parameter to control the identifiability and quality',
       0.05,
       float),
  )

  # Register the options in the order listed above.
  parser = argparse.ArgumentParser()
  for flag, help_text, default_value, value_type in arg_specs:
    parser.add_argument(
        flag, help=help_text, default=default_value, type=value_type)

  args = parser.parse_args()

  # Run the experiment with the parsed options.
  exp_main(args)