"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments

Reference: Jinsung Yoon, Lydia N. Drumright, Mihaela van der Schaar, 
"Anonymization through Data Synthesis using Generative Adversarial Networks (ADS-GAN):
A harmonizing advancement for AI in medicine," 
IEEE Journal of Biomedical and Health Informatics (JBHI), 2019.
Paper link: https://ieeexplore.ieee.org/document/9034117
Last updated date: December 22nd, 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)
-----------------------------
compute_wd.py
- Estimate the Wasserstein distance between original and synthetic data
"""

import numpy as np
import tensorflow as tf
from tqdm import tqdm
import utils as ut
import pandas as pd

def _split_half(data):
  """Split the rows of a 2-D array into a first half (train) and the rest (test).

  Args:
    data: 2-D numpy array (rows, features).

  Returns:
    (train, test): the first len(data) // 2 rows and the remaining rows.
  """
  half = len(data) // 2
  return data[:half, :], data[half:, :]


def compute_wd (orig_data, synth_data, mb_size = 256, h_dim = 50, iterations = 500):
  """Estimate the Wasserstein distance between original and synthetic data.

  A small WGAN-style critic (one hidden ReLU layer, weights clipped to
  [-0.1, 0.1], trained with RMSProp) is fit on the first half of each dataset
  to maximize mean(D(orig)) - mean(D(synth)). That same gap, evaluated on the
  held-out second halves, is returned as the distance estimate.

  Args:
    orig_data: original data (pandas DataFrame)
    synth_data: synthetically generated data (pandas DataFrame containing
      every column of orig_data)
    mb_size: mini-batch size
    h_dim: hidden state dimension of the critic
    iterations: training iterations

  Returns:
    WD_value: critic-based Wasserstein distance estimate on the test halves

  Raises:
    ValueError: if either dataset has fewer than 2 rows (no train/test split).
    KeyError: if synth_data lacks a column present in orig_data.
  """
  # Use one (sorted) column order for both datasets so feature j means the
  # same thing in both arrays.
  all_cols = sorted(orig_data.columns)
  orig_data_arr = orig_data[all_cols].to_numpy(dtype=float)
  synth_data_arr = synth_data[all_cols].to_numpy(dtype=float)

  if len(orig_data_arr) < 2 or len(synth_data_arr) < 2:
    raise ValueError("compute_wd needs at least 2 rows in each dataset "
                     "to form train/test halves.")

  # Normalize both datasets jointly so they share a single scale. Normalizing
  # each one separately would rescale away per-column shift/scale differences
  # between real and synthetic data -- exactly what this metric should detect.
  n_orig = len(orig_data_arr)
  joint_arr, _ = ut.data_normalization(
      np.concatenate([orig_data_arr, synth_data_arr], axis=0))
  orig_data_arr = joint_arr[:n_orig]
  synth_data_arr = joint_arr[n_orig:]

  x_dim = orig_data_arr.shape[1]

  # Divide train / test, each dataset by its own length (sizes may differ).
  orig_data_train, orig_data_test = _split_half(orig_data_arr)
  synth_data_train, synth_data_test = _split_half(synth_data_arr)

  # Build the critic in a private graph: repeated calls then do not pile up
  # variables in the default graph, and v1 placeholders work even when TF2
  # eager execution is enabled.
  graph = tf.Graph()
  with graph.as_default():
    X = tf.compat.v1.placeholder(tf.float32, shape = [None, x_dim])
    X_hat = tf.compat.v1.placeholder(tf.float32, shape = [None, x_dim])

    # Discriminator (critic) parameters
    D_W1 = tf.Variable(ut.xavier_init([x_dim, h_dim]))
    D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
    D_W2 = tf.Variable(ut.xavier_init([h_dim, 1]))
    D_b2 = tf.Variable(tf.zeros(shape=[1]))
    theta_D = [D_W1, D_W2, D_b1, D_b2]

    def discriminator(x):
      D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
      return tf.matmul(D_h1, D_W2) + D_b2

    D_real = discriminator(X)
    D_fake = discriminator(X_hat)

    # The critic maximizes this gap, hence the optimizer minimizes -D_loss.
    D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
    D_solver = (tf.compat.v1.train.RMSPropOptimizer(learning_rate=1e-4)
                .minimize(-D_loss, var_list=theta_D))

    # Weight clipping, as in the original WGAN, to keep the critic bounded.
    clip_D = [p.assign(tf.clip_by_value(p, -0.1, 0.1)) for p in theta_D]
    init_op = tf.compat.v1.global_variables_initializer()

  with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)

    n_orig_train = len(orig_data_train)
    n_synth_train = len(synth_data_train)
    for _ in tqdm(range(iterations)):
      # Sample real and synthetic mini-batches independently: the critic
      # compares distributions, so rows need not be paired.
      X_mb = orig_data_train[ut.sample_X(n_orig_train, mb_size), :]
      X_hat_mb = synth_data_train[ut.sample_X(n_synth_train, mb_size), :]

      sess.run(D_solver, feed_dict = {X: X_mb, X_hat: X_hat_mb})
      # Clip in a separate run so it is guaranteed to see the updated weights
      # (ops fetched in one sess.run call have no ordering guarantee).
      sess.run(clip_D)

    # Test: evaluate the trained critic's gap on the held-out halves.
    WD_value = sess.run(D_loss, feed_dict = {X: orig_data_test, X_hat: synth_data_test})

  return WD_value