"""
Generating high-fidelity privacy-conscious synthetic patient data for causal effect estimation with multiple treatments
Treatment outcome estimator class
"""

#%% Import necessary packages
import tensorflow as tf
import numpy as np
import pandas as pd
from tqdm import tqdm
import utils as ut
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import pickle

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import euclidean_distances


class TreatmentOutcomeEstimator(BaseEstimator):
    """Treatment outcome estimator that estimates outcomes and treatment effects.

    The model output is the sum of two terms: a sigmoid-activated two-hidden-layer
    network applied to the covariates, and a linear layer applied to the
    treatment columns. The weights of the linear layer (rescaled with the
    outcome normalization parameter ``'max_val'``) are exposed as ``effects_``.

    Parameters
    ----------
    mb_size : integer, batch size
    h_dim: number of neurons in hidden layer
    iterations: number of iterations
    lr: learning rate
    beta1: beta value in AdamOptimizer
    test_size: train test split value
    out_randomness: randomness to add to the outcome
    gamma: strength of regularizer
    set_drop_rate: drop out rate

    Attributes set by ``fit`` (and restored by ``load_info``)
    ----------
    t_columns_ : treatment column names
    covar_columns_ : covariate column names
    covar_train_norm_params_, treat_train_norm_params_, y_train_norm_params_ :
        normalization parameters returned by ``ut.data_normalization`` on the
        training split
    r2_train_, r2_test_ : R^2 scores on the (normalized) train / test split
    effects_ : treatment weights multiplied by ``y_train_norm_params_['max_val']``
    sess_ : the TensorFlow session holding the graph and variables
    """

    # Attributes persisted by save_info / load_info. The pickle file holds a
    # plain list in exactly this order, so the order must not change.
    _PERSISTED_ATTRS = (
        'mb_size', 'h_dim', 'iterations', 'lr', 'beta1', 'test_size',
        'out_randomness', 'gamma', 'set_drop_rate',
        't_columns_', 'covar_columns_', 'is_fitted_',
        'covar_train_norm_params_', 'treat_train_norm_params_',
        'y_train_norm_params_', 'r2_train_', 'r2_test_', 'effects_',
    )

    def __init__(self, mb_size = 128, h_dim = 30, iterations = 2000, lr = 1e-4, beta1 = 0.5, test_size = 0.3, out_randomness = 0.3, gamma = 0.01, set_drop_rate = 0.5):
        self.mb_size = mb_size
        self.h_dim = h_dim
        self.iterations = iterations
        self.lr = lr
        self.beta1 = beta1
        self.test_size = test_size
        self.out_randomness = out_randomness
        self.gamma = gamma
        self.set_drop_rate = set_drop_rate

    def _close_session(self):
        """Close the current TensorFlow session, if there is one."""
        sess = getattr(self, 'sess_', None)
        if sess is not None:
            sess.close()
            self.sess_ = None

    def _build_graph(self, x_dim, t_dim, o_dim = 1):
        """Reset the default graph and build the estimator network in it.

        Only constructs ops; no session is created and nothing is trained.
        Variables are created in a fixed order (W1, b1, W2, b2, W3, b3, T_W,
        T_b, then the optimizer's variables) so that their auto-generated
        names match checkpoints written by ``save_info``.

        Parameters
        ----------
        x_dim : int, number of covariate columns
        t_dim : int, number of treatment columns
        o_dim : int, outcome dimension

        Returns
        -------
        dict with the tensors/ops needed for training and evaluation:
        ``covar``, ``treat``, ``o_truth``, ``drop_rate``, ``o_estimate``,
        ``T_W``, ``loss``, ``solver``.
        """
        tf.compat.v1.reset_default_graph()

        # Placeholders: features, treatments, outcome
        covar = tf.compat.v1.placeholder(tf.float64, shape = [None, x_dim], name = 'covar')
        treat = tf.compat.v1.placeholder(tf.float64, shape = [None, t_dim], name = 'treat')
        o_truth = tf.compat.v1.placeholder(tf.float64, shape = [None, o_dim])

        # Covariate network weights
        W1 = tf.Variable(ut.xavier_init([x_dim, self.h_dim]), dtype=tf.float64)
        b1 = tf.Variable(tf.zeros(shape=[self.h_dim], dtype=tf.float64), dtype=tf.float64)

        W2 = tf.Variable(ut.xavier_init([self.h_dim, self.h_dim]), dtype=tf.float64)
        b2 = tf.Variable(tf.zeros(shape=[self.h_dim], dtype=tf.float64), dtype=tf.float64)

        W3 = tf.Variable(ut.xavier_init([self.h_dim, o_dim]), dtype=tf.float64)
        b3 = tf.Variable(tf.zeros(shape=[o_dim], dtype=tf.float64), dtype=tf.float64)

        # Linear treatment layer; its weights are read out as the effects
        T_W = tf.Variable(ut.xavier_init([t_dim, o_dim]), dtype=tf.float64)
        T_b = tf.Variable(tf.zeros(shape=[o_dim], dtype=tf.float64), dtype=tf.float64)

        theta_E = [W1, W2, W3, T_W, b1, b2, b3, T_b]

        # Fed with set_drop_rate during training and 0.0 for evaluation
        drop_rate = tf.compat.v1.placeholder(tf.float64, shape=(), name = 'drop_rate')

        # Forward pass: dropout on inputs and on both hidden layers
        x_drop = tf.nn.dropout(covar, rate = drop_rate)
        t_drop = tf.nn.dropout(treat, rate = drop_rate)
        h1 = tf.nn.relu(tf.matmul(x_drop, W1) + b1)
        h1_drop = tf.nn.dropout(h1, rate = drop_rate)
        h2 = tf.nn.relu(tf.matmul(h1_drop, W2) + b2)
        h2_drop = tf.nn.dropout(h2, rate = drop_rate)
        covar_term = tf.nn.sigmoid(tf.matmul(h2_drop, W3) + b3)
        treat_term = tf.matmul(t_drop, T_W) + T_b
        # Named so it can be looked up again in predict
        o_estimate = tf.identity(covar_term + treat_term, name = 'o_estimate')

        # Loss: mean per-sample L2 norm of the residual plus L2 weight penalty
        regularizers = tf.nn.l2_loss(W1) + tf.nn.l2_loss(W2) + tf.nn.l2_loss(W3) + tf.nn.l2_loss(T_W)
        loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(input_tensor=tf.square(o_truth - o_estimate), axis=1))) + self.gamma * regularizers

        # Solver
        solver = tf.compat.v1.train.AdamOptimizer(learning_rate = self.lr, beta1 = self.beta1).minimize(loss, var_list = theta_E)

        return {
            'covar': covar,
            'treat': treat,
            'o_truth': o_truth,
            'drop_rate': drop_rate,
            'o_estimate': o_estimate,
            'T_W': T_W,
            'loss': loss,
            'solver': solver,
        }

    def save_info(self, file_name):
        """Persist the fitted estimator under ``data/``.

        Writes ``data/<file_name>.pickle`` (hyper-parameters and fitted
        attributes), ``data/<file_name>.ckpt`` (TensorFlow variables) and
        ``data/<file_name>.pb`` (graph definition).

        Parameters
        ----------
        file_name : str, base name (without extension) of the files to write
        """
        check_is_fitted(self, 'is_fitted_')

        # save all class parameters
        all_parms_to_save = [getattr(self, name) for name in self._PERSISTED_ATTRS]
        with open('data/' + file_name + '.pickle', 'wb') as pickle_file:
            pickle.dump(all_parms_to_save, pickle_file)

        # save all tensorflow session, graph, and variables
        # Collect the variables from this estimator's own graph: the global
        # default graph may have been reset by another estimator since fit.
        with self.sess_.graph.as_default():
            saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        saver.save(self.sess_, 'data/' + file_name + '.ckpt')
        tf.io.write_graph(self.sess_.graph_def, 'data/', file_name + '.pb', False)

    def load_info(self, file_name, X = None, y = None):
        """Restore an estimator previously written by ``save_info``.

        The graph is rebuilt from the persisted hyper-parameters and column
        lists and the variables are restored from the checkpoint; no training
        is performed.

        Parameters
        ----------
        file_name : str, base name (without extension) used in ``save_info``
        X, y : ignored; accepted only for backward compatibility with callers
            that passed training data to rebuild the graph.

        Raises
        ------
        ValueError
            If the pickle file does not hold the expected number of entries.
        """
        # load model parameters
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # files that were produced by save_info from a trusted source.
        with open('data/' + file_name + '.pickle', 'rb') as pickle_file:
            all_parms_to_load = pickle.load(pickle_file)

        if len(all_parms_to_load) != len(self._PERSISTED_ATTRS):
            raise ValueError(
                'expected %d saved parameters, found %d'
                % (len(self._PERSISTED_ATTRS), len(all_parms_to_load)))
        for name, value in zip(self._PERSISTED_ATTRS, all_parms_to_load):
            setattr(self, name, value)

        # Rebuild the graph with the same dimensions as the saved model.
        # Close any live session first: resetting the default graph while a
        # session is active is undefined behaviour in TensorFlow.
        self._close_session()
        self._build_graph(len(self.covar_columns_), len(self.t_columns_))

        # load all tensor related
        self.sess_ = tf.compat.v1.Session()
        saver = tf.compat.v1.train.Saver()
        saver.restore(self.sess_, 'data/' + file_name + '.ckpt')

    def fit(self, X, y, t_columns = None, covar_columns = None):
        """Train the estimator on a data frame of covariates and treatments.

        Parameters
        ----------
        X : data frame, shape (n_samples, n_features)
            The training input samples; must contain the treatment and
            covariate columns.
        y : array-like, shape (n_samples,)
            The target values.
        t_columns : sequence of column names or None
            Treatment columns. If None, every column whose name starts with
            ``'drugs_'`` is used.
        covar_columns : sequence of column names or None
            Covariate columns. If None, ``ut.select_covar_columns(X.columns)``
            is used.

        Returns
        -------
        self : object
            Returns self.
        """
        # `is None` rather than `== None`: the latter is element-wise (and
        # therefore ambiguous) when an array of column names is passed.
        if t_columns is None:
            self.t_columns_ = [col for col in X.columns if col.startswith('drugs_')]
        else:
            self.t_columns_ = t_columns

        if covar_columns is None:
            self.covar_columns_ = list(ut.select_covar_columns(X.columns))
        else:
            self.covar_columns_ = covar_columns

        assert len(self.t_columns_) > 0, 'at least one treatment column is required'

        covar_df = X[self.covar_columns_]
        treat_df = X[self.t_columns_]

        covar_train, covar_test, treat_train, treat_test, y_train, y_test = train_test_split(
            covar_df, treat_df, y, test_size = self.test_size, random_state=10)

        # dimensions
        x_dim = len(covar_train.columns)
        t_dim = len(treat_train.columns)
        sample_no = covar_train.shape[0]

        # normalization for training data; the parameters are kept so that
        # test data and later predictions are normalized the same way
        covar_train_arr, self.covar_train_norm_params_ = ut.data_normalization(np.asarray(covar_train))
        treat_train_arr, self.treat_train_norm_params_ = ut.data_normalization(np.asarray(treat_train))
        y_train_arr, self.y_train_norm_params_ = ut.data_normalization(np.asarray(y_train))
        y_train_arr = y_train_arr.reshape([-1, 1])

        # normalization for test data, using the training parameters
        covar_test_arr, _ = ut.data_normalization(np.asarray(covar_test), normalization_params = self.covar_train_norm_params_)
        treat_test_arr, _ = ut.data_normalization(np.asarray(treat_test), normalization_params = self.treat_train_norm_params_)
        y_test_arr, _ = ut.data_normalization(np.asarray(y_test), normalization_params = self.y_train_norm_params_)
        y_test_arr = y_test_arr.reshape([-1, 1])

        # Release the session of a previous fit/load before the default graph
        # is reset, then build a fresh graph and session.
        self._close_session()
        graph = self._build_graph(x_dim, t_dim)
        covar = graph['covar']
        treat = graph['treat']
        drop_rate = graph['drop_rate']

        self.sess_ = tf.compat.v1.Session()
        self.sess_.run(tf.compat.v1.global_variables_initializer())

        # Training iterations: one sampled mini-batch per step
        for _ in tqdm(range(self.iterations)):
            batch_idx = ut.sample_X(sample_no, self.mb_size)
            self.sess_.run(graph['solver'], feed_dict = {
                covar: covar_train_arr[batch_idx, :],
                treat: treat_train_arr[batch_idx, :],
                graph['o_truth']: y_train_arr[batch_idx],
                drop_rate: self.set_drop_rate,
            })

        # Evaluation with dropout disabled
        train_estimate, treat_weights = self.sess_.run(
            [graph['o_estimate'], graph['T_W']],
            feed_dict = {covar: covar_train_arr, treat: treat_train_arr, drop_rate: 0.0})
        self.r2_train_ = r2_score(y_train_arr, train_estimate)
        # note here the effects_ should be in original scale
        self.effects_ = treat_weights * self.y_train_norm_params_['max_val']

        test_estimate = self.sess_.run(
            graph['o_estimate'],
            feed_dict = {covar: covar_test_arr, treat: treat_test_arr, drop_rate: 0.0})
        self.r2_test_ = r2_score(y_test_arr, test_estimate)

        self.is_fitted_ = True

        # 'fit' should always return 'self'
        return self

    def predict(self, X):
        """Predict outcomes for new samples, blended with Gaussian noise.

        The network estimate is mapped back to the original outcome scale and
        mixed with normally distributed noise (same mean and standard
        deviation as the estimates) using the weight ``out_randomness``.

        Parameters
        ----------
        X : data frame, shape (n_samples, n_features)
            The input features; must contain the fitted treatment and
            covariate columns.

        Returns
        -------
        y : ndarray
            The noisy predictions, with the shape returned by
            ``ut.data_renormalization`` for the network output.
        """
        check_is_fitted(self, 'is_fitted_')

        covar_arr = np.asarray(X[self.covar_columns_])
        treat_arr = np.asarray(X[self.t_columns_])
        covar_arr, _ = ut.data_normalization(covar_arr, normalization_params = self.covar_train_norm_params_)
        treat_arr, _ = ut.data_normalization(treat_arr, normalization_params = self.treat_train_norm_params_)

        # Look the tensors up in this estimator's own graph. The global
        # default graph may belong to another estimator fitted afterwards.
        graph = self.sess_.graph
        o_estimate = graph.get_tensor_by_name("o_estimate:0")
        covar = graph.get_tensor_by_name("covar:0")
        treat = graph.get_tensor_by_name("treat:0")
        drop_rate = graph.get_tensor_by_name("drop_rate:0")

        o_predicted = self.sess_.run(o_estimate, feed_dict = {covar: covar_arr, treat: treat_arr, drop_rate: 0.0})
        # Renormalization back to the original outcome scale
        o_predicted = ut.data_renormalization(o_predicted, self.y_train_norm_params_)

        random_arr = np.random.normal(size = o_predicted.shape, loc = np.mean(o_predicted), scale = np.std(o_predicted))
        predicted_return = (1 - self.out_randomness) * o_predicted + self.out_randomness * random_arr

        return predicted_return

# %%
