"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments


Reference: Jinsung Yoon, Lydia N. Drumright, Mihaela van der Schaar, 
"Anonymization through Data Synthesis using Generative Adversarial Networks (ADS-GAN):
A harmonizing advancement for AI in medicine," 
IEEE Journal of Biomedical and Health Informatics (JBHI), 2019.
Paper link: https://ieeexplore.ieee.org/document/9034117
Last updated date: December 22nd, 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)
-----------------------------
adsgan.py
- Generate synthetic data using adapted ADSGAN framework
"""

#%% Import necessary packages
import tensorflow as tf
import numpy as np
import pandas as pd

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from metrics.compute_identifiability import compute_identifiability
import utils as ut

def adsgan(orig_data_df, params):
  """Generate synthetic data with the adapted ADS-GAN framework.

  A conditional WGAN-GP generator (conditioned on a real record plus noise)
  is trained on a random 67% split of the normalized original data. Besides
  the adversarial loss, the generator loss contains an identifiability term,
  weighted by ``lamda``, that pushes each synthetic record away from the real
  record it was conditioned on.

  Args:
    orig_data_df: original data as a pandas DataFrame; every column is used
      as a feature.
    params: network parameters (accessed by key)
      mb_size: mini-batch size
      z_dim: random state dimension
      h_dim: hidden state dimension
      lamda: identifiability parameter
      iterations: number of generator training iterations

  Returns:
    A tuple ``(return_df, wd_whole_set, id_whole_set)``:
      return_df: synthetic DataFrame with the same columns and dtypes as
        ``orig_data_df`` and one row per original row.
      wd_whole_set: per-iteration value of ``-D_loss`` on the training split
        (critic Wasserstein estimate minus the gradient penalty).
      id_whole_set: per-iteration identifiability loss ``G_loss1`` on the
        training split.
  """
  # Reset the tensorflow graph
  tf.compat.v1.reset_default_graph()

  ## Parameters
  # Feature no
  x_dim = len(orig_data_df.columns)
  mb_size = params['mb_size']
  z_dim = params['z_dim']
  h_dim = params['h_dim']
  lamda = params['lamda']
  iterations = params['iterations']
  # Critic (discriminator) updates per generator update
  d_iterations = 500

  # WGAN-GP parameters
  lam = 10          # gradient-penalty weight
  lr = 1e-4         # Adam learning rate for both networks
  randomness = 0.   # mixing weight of R into the generator output (0 => R has no effect)
  beta = 0.01       # weight of the G_loss3 term

  #%% Data Preprocessing
  orig_data_arr = np.asarray(orig_data_df)
  orig_data_arr, normalization_params = ut.data_normalization(orig_data_arr)
  # Only the training split is used below; the held-out third is discarded.
  orig_train, _ = train_test_split(orig_data_arr, test_size=0.33)
  no_orig = len(orig_data_arr)
  no_train = len(orig_train)

  #%% Placeholders
  # Feature
  X = tf.compat.v1.placeholder(tf.float32, shape=[None, x_dim])
  # Random Variable
  Z = tf.compat.v1.placeholder(tf.float32, shape=[None, z_dim])
  # Random Variable for output
  R = tf.compat.v1.placeholder(tf.float32, shape=[None, x_dim])

  #%% Discriminator
  D_W1 = tf.Variable(ut.xavier_init([x_dim, h_dim]))
  D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

  D_W2 = tf.Variable(ut.xavier_init([h_dim, h_dim]))
  D_b2 = tf.Variable(tf.zeros(shape=[h_dim]))

  D_W3 = tf.Variable(ut.xavier_init([h_dim, 1]))
  D_b3 = tf.Variable(tf.zeros(shape=[1]))

  theta_D = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]

  #%% Generator
  G_W1 = tf.Variable(ut.xavier_init([z_dim + x_dim, h_dim]))
  G_b1 = tf.Variable(tf.zeros(shape=[h_dim], dtype=tf.float32))

  G_W2 = tf.Variable(ut.xavier_init([h_dim, h_dim]))
  G_b2 = tf.Variable(tf.zeros(shape=[h_dim], dtype=tf.float32))

  G_W3 = tf.Variable(ut.xavier_init([h_dim, h_dim]))
  G_b3 = tf.Variable(tf.zeros(shape=[h_dim], dtype=tf.float32))

  G_W4 = tf.Variable(ut.xavier_init([h_dim, x_dim]))
  G_b4 = tf.Variable(tf.zeros(shape=[x_dim], dtype=tf.float32))

  theta_G = [G_W1, G_W2, G_W3, G_W4, G_b1, G_b2, G_b3, G_b4]

  #%% Generator and discriminator functions
  def generator(z, x, r):
    """Map (noise z, real record x) to a synthetic record in (0, 1)."""
    inputs = tf.concat([z, x], axis=1)
    G_h1 = tf.nn.tanh(tf.matmul(inputs, G_W1) + G_b1)
    G_h2 = tf.nn.tanh(tf.matmul(G_h1, G_W2) + G_b2)
    G_h3 = tf.nn.tanh(tf.matmul(G_h2, G_W3) + G_b3)
    G_log_prob = tf.nn.sigmoid(tf.matmul(G_h3, G_W4) + G_b4)
    return (1 - randomness) * G_log_prob + randomness * r

  def discriminator(x):
    """Unbounded critic score (WGAN critic, no output activation)."""
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_h2 = tf.nn.relu(tf.matmul(D_h1, D_W2) + D_b2)
    return tf.matmul(D_h2, D_W3) + D_b3

  #%% Structure
  G_sample = generator(Z, X, R)
  D_real = discriminator(X)
  D_fake = discriminator(G_sample)

  # WGAN-GP: gradient penalty instead of weight clipping.
  # Random interpolation between real and generated samples (Alg. 1, line 6).
  eps = tf.compat.v1.placeholder(tf.float32, shape=[None, 1])
  X_inter = eps * X + (1. - eps) * G_sample

  # Penalize critic gradient norms that deviate from 1 (Alg. 1, line 7).
  grad = tf.gradients(discriminator(X_inter), [X_inter])[0]
  grad_norm = tf.sqrt(tf.reduce_sum((grad) ** 2 + 1e-8, axis=1))
  grad_pen = lam * tf.reduce_mean((grad_norm - 1) ** 2)

  # Loss function
  D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + grad_pen

  # Real records paired at random with the generated ones.
  X_J = tf.random.shuffle(X)
  dist_X_G = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(X - G_sample), axis=1))
  dist_XJ_G = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(X_J - G_sample), axis=1))

  # Identifiability term: reward distance from the conditioning record.
  G_loss1 = -tf.reduce_mean(dist_X_G)
  # NOTE(review): -D_loss also contains -grad_pen, so the generator gradient
  # flows through the gradient penalty. The textbook WGAN-GP generator loss
  # is -tf.reduce_mean(D_fake). Kept as-is; confirm this is intended.
  G_loss2 = -D_loss
  # (sign(d) + 1) * d == 2 * max(d, 0) with d = dist_X_G - dist_XJ_G: penalize
  # samples farther from their own record than from a randomly paired one.
  G_loss3 = tf.reduce_mean((tf.math.sign(dist_X_G - dist_XJ_G) - (-1)) * (dist_X_G - dist_XJ_G))
  G_loss = G_loss2 + lamda * G_loss1 + beta * G_loss3

  # Solver
  D_solver = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(D_loss, var_list=theta_D)
  G_solver = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(G_loss, var_list=theta_G)

  #%% Iterations
  wd_whole_set = []
  id_whole_set = []

  # Fixed interpolation weights for the whole-training-set monitoring pass,
  # so the logged curves are comparable across iterations.
  eps_train = ut.sample_eps(no_train)

  # The context manager guarantees the session is closed, even on error.
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())

    for _ in tqdm(range(iterations)):
      # Discriminator training
      for _ in range(d_iterations):
        Z_mb = ut.sample_Z(mb_size, z_dim)
        output_mb = ut.sample_R(mb_size, x_dim)
        # Indices are drawn from [0, no_train), so they must index the
        # training split, not the full array.
        X_idx = ut.sample_X(no_train, mb_size)
        X_mb = orig_train[X_idx, :]
        # Fresh interpolation weights every step, as in WGAN-GP Algorithm 1
        # (ut.sample_eps presumably returns U[0,1] values of shape (m, 1)).
        eps_mb = ut.sample_eps(mb_size)
        sess.run(D_solver, feed_dict={X: X_mb, Z: Z_mb, eps: eps_mb, R: output_mb})

      # Generator Training
      Z_mb = ut.sample_Z(mb_size, z_dim)
      output_mb = ut.sample_R(mb_size, x_dim)
      X_idx = ut.sample_X(no_train, mb_size)
      X_mb = orig_train[X_idx, :]
      eps_mb = ut.sample_eps(mb_size)
      sess.run(G_solver, feed_dict={X: X_mb, Z: Z_mb, eps: eps_mb, R: output_mb})

      # Monitoring: losses over the whole training split.
      dist_loss, identifiability = sess.run(
          [D_loss, G_loss1],
          feed_dict={X: orig_train,
                     Z: ut.sample_Z(no_train, z_dim),
                     eps: eps_train,
                     R: ut.sample_R(no_train, x_dim)})
      wd_whole_set.append(-dist_loss)
      id_whole_set.append(identifiability)

    #%% Output Generation: one synthetic record per original record.
    synth_data = sess.run(
        G_sample,
        feed_dict={Z: ut.sample_Z(no_orig, z_dim),
                   X: orig_data_arr,
                   R: ut.sample_R(no_orig, x_dim)})

  # Renormalization
  synth_data = ut.data_renormalization(synth_data, normalization_params)

  # Binary features: columns with at most two distinct values in the
  # normalized original data are rounded to the nearest integer.
  for i in range(x_dim):
    if len(np.unique(orig_data_arr[:, i])) <= 2:
      synth_data[:, i] = np.round(synth_data[:, i]).astype(int)

  return_df = pd.DataFrame(data=synth_data, columns=orig_data_df.columns)

  round_cols = ut.col_to_round(return_df)
  return_df[round_cols] = return_df[round_cols].round(0)
  return_df = return_df.astype(orig_data_df.dtypes)

  return return_df, wd_whole_set, id_whole_set