"""GANITE Codebase.

Reference: Jinsung Yoon, James Jordon, Mihaela van der Schaar, 
"GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets", 
International Conference on Learning Representations (ICLR), 2018.

Paper link: https://openreview.net/forum?id=ByKWUeWA-

Last updated Date: April 25th 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)

-----------------------------

ganite.py

(1) Import data
(2) Train GANITE & Estimate potential outcomes
(3) Evaluate the performances
  - PEHE
  - ATE
"""

## Necessary packages
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import warnings
warnings.filterwarnings("ignore")
from logging import getLogger

# 1. GANITE model
from causally.model.ganite.model import model
# 2. Data loading
from causally.model.ganite.data_loading import data_loading
# 3. Metrics
from causally.model.ganite.metrics import PEHE, ATE


def _run_ganite_pass(config, logger, first_order, in_sample):
    """Train and score GANITE once per dataset order for one evaluation pass.

    Iterates ``config['start_order']`` from ``first_order`` up to and
    including ``config['end_order']``. The counter is kept inside ``config``
    because ``config`` is what gets handed to ``data_loading`` and ``model``
    (presumably they read the current order from it -- verify against those
    modules).

    Args:
        config: dict-like configuration; reads 'end_order', 'model' and
            'dataset', and reads/writes 'start_order'.
        logger: logger used for progress and result messages.
        first_order: dataset order the pass starts from.
        in_sample: when True the model is scored on the training features
            and training potential outcomes, otherwise on the test ones.

    Returns:
        Tuple ``(pehe_scores, ate_scores)``: two lists with one entry per
        dataset order, each rounded to 4 decimals.
    """
    split_label = 'train' if in_sample else 'test'
    pehe_scores = []
    ate_scores = []

    config['start_order'] = first_order
    while config['start_order'] <= config['end_order']:
        train_x, train_t, train_y, train_potential_y, test_x, test_potential_y = \
            data_loading(config)
        logger.info(config['model'] + ' dataset is ready.')

        if in_sample:
            eval_x, eval_potential_y = train_x, train_potential_y
        else:
            eval_x, eval_potential_y = test_x, test_potential_y

        ## Potential outcome estimations by GANITE
        y_hat = model(train_x, train_t, train_y, eval_x, config)
        logger.info('Finish GANITE training and potential outcome estimations')

        # 1. PEHE
        pehe = PEHE(eval_potential_y, y_hat)
        pehe_scores.append(np.round(pehe, 4))

        # 2. ATE
        ate = ATE(eval_potential_y, y_hat)
        ate_scores.append(np.round(ate, 4))

        # Report both metrics; the previous format string had one placeholder
        # too few, so the ATE value was silently dropped from the log.
        logger.info('[{},{}] {} result:\nPEHE: {}, ATE: {}'.format(
            config['model'],
            config['dataset'] + str(config['start_order']),
            split_label, pehe, ate))
        config['start_order'] += 1

    return pehe_scores, ate_scores


def ganite (config):
    """Run GANITE over a range of dataset orders and collect PEHE / ATE.

    Two passes are made over the orders ``start_order .. end_order``
    (inclusive): an out-of-sample pass scored on the test split and an
    in-sample pass scored on the training split. The model is retrained for
    every order in each pass.

    Args:
        config: dict-like configuration. This function reads 'start_order',
            'end_order', 'model' and 'dataset'. It overwrites 'epochs' with
            3000 and advances 'start_order' while iterating, so the caller's
            object is mutated.

    Returns:
        dict with keys 'out_ate', 'out_pehe', 'in_ate', 'in_pehe'; each value
        is a list holding one metric (rounded to 4 decimals) per dataset
        order.
    """
    logger = getLogger()
    # NOTE(review): any caller-supplied 'epochs' is overridden here.
    config['epochs'] = 3000

    # Remember where the caller asked to start so that the in-sample pass
    # covers the same orders as the out-of-sample pass (it used to restart
    # from a hard-coded 1).
    first_order = config['start_order']

    out_pehe, out_ate = _run_ganite_pass(
        config, logger, first_order, in_sample=False)
    in_pehe, in_ate = _run_ganite_pass(
        config, logger, first_order, in_sample=True)

    return {
        'out_ate': out_ate,
        'out_pehe': out_pehe,
        'in_ate': in_ate,
        'in_pehe': in_pehe,
    }


if __name__ == '__main__':

    # Inputs for the main function
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_name',
        choices=['twin'],
        default='twin',
        type=str)
    parser.add_argument(
        '--train_rate',
        help='the ratio of training data',
        default=0.8,
        type=float)
    parser.add_argument(
        '--h_dim',
        help='hidden state dimensions (should be optimized)',
        default=30,
        type=int)
    parser.add_argument(
        '--iteration',
        help='Training iterations (should be optimized)',
        default=10000,
        type=int)
    parser.add_argument(
        '--batch_size',
        help='the number of samples in mini-batch (should be optimized)',
        default=256,
        type=int)
    parser.add_argument(
        '--alpha',
        help='hyper-parameter to adjust the loss importance (should be optimized)',
        default=1,
        type=int)
    # The arguments below supply the keys that ganite() itself reads.
    parser.add_argument(
        '--model',
        help='model name used in log messages',
        default='ganite',
        type=str)
    parser.add_argument(
        '--start_order',
        help='first dataset order to evaluate',
        default=1,
        type=int)
    parser.add_argument(
        '--end_order',
        help='last dataset order to evaluate (inclusive)',
        default=1,
        type=int)

    args = parser.parse_args()

    # ganite() indexes and assigns into its config (config['epochs'] = ...),
    # which an argparse.Namespace does not support, so pass a plain dict.
    config = vars(args)
    # ganite() reads config['dataset']; reuse the chosen data name for it.
    # NOTE(review): data_loading()/model() may need further keys -- confirm.
    config['dataset'] = config['data_name']

    # Calls main function; it returns a single dict of metric lists.
    metrics = ganite(config)
    print(metrics)