#!/usr/bin/env python
"""CEVAE model on IHDP
"""


import os
import edward as ed
import tensorflow as tf
# from sacred import Experiment
from copy import deepcopy
import time
import inspect

# ex = Experiment('cevae mods')

# from tensorflow.contrib.distributions import Normal
from edward.models import Bernoulli, Normal
from progressbar import ETA, Bar, Percentage, ProgressBar

from data import IHDP
from evaluation import Evaluator
import numpy as np
import time
from scipy.stats import sem

from utils import fc_net, get_y0_y1, wasserstein, safe_sqrt
from argparse import ArgumentParser

# ---------------------------------------------------------------------------
# Command-line options (parsed at import time).
# ---------------------------------------------------------------------------
parser = ArgumentParser()
parser.add_argument('-reps', default=100, type=int)
# parser.add_argument('-earl', type=int, default=10)
parser.add_argument('-lr', default=0.001, type=float)
parser.add_argument('-opt', default='adam', choices=['adam', 'adamax'])
# parser.add_argument('-epochs', type=int, default=100)
parser.add_argument('-print_every', default=10, type=int)
args = parser.parse_args()

# Extra flag attached to the parsed namespace; nothing in the visible part of
# this file reads it back.
args.true_post = True


# ---------------------------------------------------------------------------
# Data and per-replication result buffers.
# ---------------------------------------------------------------------------
dataset = IHDP(replications=args.reps)
dimx = 25
# One row per replication; run() stores the three values returned by
# Evaluator.calc_stats (reported there as ite / ate / pehe) in each row.
scores = np.zeros((args.reps, 3))
scores_test = np.zeros_like(scores)


# ---------------------------------------------------------------------------
# Network hyper-parameters shared by every model built in run().
# ---------------------------------------------------------------------------
# Batch size during training; used as the leading placeholder dimension,
# where None leaves that dimension unspecified.
M = None

# Weight decay, forwarded to fc_net through its `lamba` keyword.
lamba = 1e-4

# Number of hidden layers and the width of each hidden layer.
nh = 3
h = 200


# @ex.config
# def cfg():
#     equiv_confounder = True
#     equiv_conf_covar = True
#     # direct_conf_covar = False
#     # num_z_conf = 10
#     decoder_switch = False
#     decoder_sigma = True
#     # second_switch = False
#     # balanced_sample = True
#     ipm_matching = True


# @ex.automain
def run(train_dir='/Users/abel/Downloads/CEVAE/models/ihdp/',
        d=20,  # latent dimension
        nh_t=20,
        equiv_confounder=False,
        equiv_conf_covar=False,
        direct_conf_covar=False,
        dim_conf=1,
        dim_z_conf=20,
        dec_repre_dim=200,
        decoder_switch=True,
        decoder_sigma=False,
        first_switch=True,
        first_repre=True,
        qy_sigma=False,
        second_switch=True,
        second_repre=True,
        nh_qz=200,
        qz_sigma=True,
        balanced_sample=False,
        imp_weighting=False,
        eval_weighting=False,
        ipm_matching=False,
        enc_alpha = 1,
        dec_matching=False,
        dec_alpha = 1,
        qz_matching=False,
        qz_alpha=1,
        enc_t=True,
        enc_obs_t=False,
        enc_x=True,
        dec_x=True,
        dec_t=True,
        dec_obs_t=False,
        no_vae=False,
        infer_y=True,
        infer_qy=True,
        eval_qy=False,
        dec_ate=False,
        dec_ate_post=False,
        validate='elbo',
        validate_step=10,
        epochs=100,
        index=None,
        ):
    all_args = deepcopy(locals())
    print(all_args)

    global ARGSPEC
    ARGSPEC = inspect.getfullargspec(run)
    non_defaults = {arg: all_args[arg] for arg, default in zip(ARGSPEC.args, ARGSPEC.defaults) if default != all_args[arg]}

    train_dir += time.strftime("%Y-%m-%d %H:%M:%S ") + str(non_defaults).replace(':', '-') + '/'
    if not os.path.exists(train_dir):
        print("make savefolder: {0:s}...".format(train_dir))
        os.makedirs(train_dir)  # Make folder

    def printf(*values, filename=train_dir+'output.txt'):
        """Print *values* to stdout and append the same line to *filename*.

        The values are rendered as ``print`` would render them (``str`` of
        each value, joined by single spaces, terminated by a newline).  The
        default log path is computed once, when this helper is defined, from
        the run directory.
        """
        message = ' '.join(str(value) for value in values)
        # Echo to the console first, so the line is shown even if the log
        # file cannot be opened.
        print(message)
        with open(filename, 'a') as log_file:
            log_file.write(message + '\n')

    def mu_sigma(input, layers, nout, var_sigma, switch, name):
        """Build the location/scale outputs of a Gaussian head via ``fc_net``.

        ``lamba`` (module level) and ``activation`` (assigned later inside
        ``run``) are read from the enclosing scopes when this is called.

        Args:
            input: tensor passed to ``fc_net`` as the network input.
            layers: hidden-layer specification forwarded to ``fc_net``.
            nout: width of each output head.
            var_sigma: when true, a second softplus-activated head provides
                the scale; otherwise the scale is ``tf.ones_like`` of the
                location.
            switch: when true, two nets are built under ``name + '_t0'`` and
                ``name + '_t1'``; otherwise a single net is built under
                ``name``.
            name: base name forwarded to ``fc_net``.

        Returns:
            With ``switch`` false: the location/scale pair (when
            ``var_sigma`` is true this is whatever ``fc_net`` returns for the
            two heads; callers unpack it as a pair).  With ``switch`` true: a
            2-tuple of such pairs, ordered ``(t0, t1)``.
        """
        def _build_head(scope):
            # Fixed unit scale: only a location head is needed.
            if not var_sigma:
                loc = fc_net(input, layers, [[nout, None]], scope,
                             lamba=lamba, activation=activation)
                return loc, tf.ones_like(loc)

            # Learned scale: linear head for the location, softplus head for
            # the scale.
            heads = [[nout, None], [nout, tf.nn.softplus]]
            return fc_net(input, layers, heads, scope,
                          lamba=lamba, activation=activation)

        if not switch:
            return _build_head(name)

        head_t0 = _build_head(name + '_t0')
        head_t1 = _build_head(name + '_t1')
        return head_t0, head_t1

    def normal(input, layers, nout, var_sigma, switch, t, name, reweight=1.):
        """Create an Edward ``Normal`` whose parameters come from ``mu_sigma``.

        Args:
            input: tensor feeding the parameter network(s).
            layers: hidden-layer specification forwarded to ``mu_sigma``.
            nout: width of the location/scale outputs.
            var_sigma: forwarded to ``mu_sigma`` (learned vs. unit scale).
            switch: when true, separate t0/t1 parameter sets are built from
                ``input`` alone and blended as ``t * p1 + (1 - t) * p0``;
                when false, ``t`` is concatenated onto ``input`` along axis 1
                and a single parameter set is built.
            t: treatment tensor used for the blend or the concatenation.
            name: name of the resulting ``Normal``; the parameter nets use
                ``name + '_mu_sigma'``.
            reweight: factor multiplied into the scale (default ``1.``).

        Returns:
            The constructed ``Normal`` random variable.
        """
        param_scope = name+'_mu_sigma'

        if not switch:
            # Single network that sees the treatment as an extra input column.
            net_input = tf.concat([input, t], 1)
            loc, spread = mu_sigma(net_input, layers, nout, var_sigma, False, param_scope)
            return Normal(loc=loc, scale=reweight*spread, name=name)

        # One parameter set per treatment arm, selected (or mixed) by t.
        arm0, arm1 = mu_sigma(input, layers, nout, var_sigma, True, param_scope)
        loc0, spread0 = arm0
        loc1, spread1 = arm1
        blended_loc = t*loc1+(1.-t)*loc0
        blended_scale = reweight*(t*spread1+(1.-t)*spread0)
        return Normal(loc=blended_loc, scale=blended_scale, name=name)

    for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test(index)):
        print('\nReplication {}/{}'.format(i + 1, args.reps))
        (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
        (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
        (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
        evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
        evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
        evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)

        # reorder features with binary first and continuous after
        perm = binfeats + contfeats
        xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

        # xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
        # evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
        #                             mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

        # zero mean, unit variance for y during training
        ym, ys = np.mean(ytr), np.std(ytr)
        ytr, yva = (ytr - ym) / ys, (yva - ym) / ys
        best_logpvalid = - np.inf
        best_pehevalid = np.inf
        best_atevalid = np.inf

        reweighting = imp_weighting or eval_weighting
        if reweighting:
            p_treat = np.mean(ttr)

        with tf.Graph().as_default():

            ed.set_seed(1)
            np.random.seed(1)
            tf.set_random_seed(1)

            sess = tf.InteractiveSession()

            x_ph_bin = tf.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
            x_ph_cont = tf.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
            t_ph = tf.placeholder(tf.float32, [M, 1])
            y_ph = tf.placeholder(tf.float32, [M, 1])

            x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
            activation = tf.nn.elu

            # CEVAE model
            # @@ DECODER
            # p(z)
            z = Normal(loc=tf.zeros([tf.shape(t_ph)[0], d]), scale=tf.ones([tf.shape(t_ph)[0], d]))

            # 1-dim equiv confounder
            if equiv_confounder:
                z_conf, z_unconf = z[:, :dim_z_conf], z[:, dim_z_conf:]
                conf = fc_net(z_conf, nh * [h], [[dim_conf, None]], 'equiv_confounder', lamba=lamba, activation=activation)

            if dec_x:
            # p(x|z)
                if not (equiv_confounder and equiv_conf_covar):
                    hx = fc_net(z, (nh - 1) * [h], [], 'px_z_shared', lamba=lamba, activation=activation)
                    logits = fc_net(hx, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba, activation=activation)
                    x1 = Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')

                    mu, sigma = fc_net(hx, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont', lamba=lamba,
                                       activation=activation)
                    x2 = Normal(loc=mu, scale=sigma, name='gaussian_px_z')
                else:
                    # logits = fc_net(conf, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba,
                    #                 activation=activation)
                    # x1 = Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')
                    #
                    # mu, sigma = fc_net(conf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont',
                    #                    lamba=lamba,
                    #                    activation=activation)
                    # x2 = Normal(loc=mu, scale=sigma, name='gaussian_px_z')

                    # impure proxy
                    if not direct_conf_covar:
                        hx = fc_net(z_conf, (nh - 1) * [h], [], 'px_z_shared', lamba=lamba, activation=activation)
                        hx2 = fc_net(z_unconf, (nh - 1) * [h], [], 'px_z2_shared', lamba=lamba, activation=activation)
                        impure_conf = tf.concat([hx, hx2], 1)
                    else:
                        hx2 = fc_net(z_unconf, (nh - 1) * [h], [], 'px_z2_shared', lamba=lamba, activation=activation)
                        impure_conf = tf.concat([conf, hx2], 1)

                    logits = fc_net(impure_conf, [h], [[len(binfeats), None]], 'px_z_bin'.format(i + 1), lamba=lamba,
                                    activation=activation)
                    x1 = Bernoulli(logits=logits, dtype=tf.float32, name='bernoulli_px_z')

                    mu, sigma = fc_net(impure_conf, [h], [[len(contfeats), None], [len(contfeats), tf.nn.softplus]], 'px_z_cont',
                                       lamba=lamba,
                                       activation=activation)
                    x2 = Normal(loc=mu, scale=sigma, name='gaussian_px_z')

            # p(t|z)
            if not equiv_confounder:
                logits = fc_net(z, [h], [[1, None]], 'pt_z', lamba=lamba, activation=activation)
            else:
                logits = fc_net(conf, [h], [[1, None]], 'pt_z', lamba=lamba, activation=activation)
                # logits = conf
            t = Bernoulli(logits=logits, dtype=tf.float32) if not dec_obs_t else t_ph

            # p(y|t,z)
            if not equiv_confounder:
                # if decoder_switch:
                #     # original
                #     if not decoder_sigma:
                #         mu2_t0 = fc_net(z, nh * [h], [[1, None]], 'py_t0z', lamba=lamba, activation=activation)
                #         mu2_t1 = fc_net(z, nh * [h], [[1, None]], 'py_t1z', lamba=lamba, activation=activation)
                #         y = Normal(loc=t * mu2_t1 + (1. - t) * mu2_t0, scale=tf.ones_like(mu2_t0))
                #     else:
                #         mu2_t0, sigma2_t0 = fc_net(z, nh * [h], [[1, None], [1, tf.nn.softplus]], 'py_t0z', lamba=lamba, activation=activation)
                #         mu2_t1, sigma2_t1 = fc_net(z, nh * [h], [[1, None], [1, tf.nn.softplus]], 'py_t1z', lamba=lamba, activation=activation)
                #         y = Normal(loc=t * mu2_t1 + (1. - t) * mu2_t0, scale=t * sigma2_t1 + (1. - t) * sigma2_t0)
                # else:
                #     if not decoder_sigma:
                #         mu2 = fc_net(z, nh * [h], [[1, None]], 'py_t0z', lamba=lamba, activation=activation)
                #         y = Normal(loc=mu2, scale=tf.ones_like(mu2))
                #     else:
                #         mu2, sigma2 = fc_net(z, nh * [h], [[1, None], [1, tf.nn.softplus]], 'py_t0z', lamba=lamba, activation=activation)
                #         y = Normal(loc=mu2, scale=sigma2)

                if not dec_matching:
                    dec_repre = z
                else:
                    dec_repre = fc_net(z, (nh-1) * [h], [[dec_repre_dim, None]], 'dec_repre', lamba=lamba, activation=activation)

                y = normal(dec_repre, [h] if dec_matching else nh * [h], 1, decoder_sigma, decoder_switch, t, 'py')

            else:
                # if not decoder_sigma:
                #     mut = fc_net(tf.concat([t, conf], 1), nh * [h], [[1, None]], 'py_tz', lamba=lamba, activation=activation)
                #     y = Normal(loc=mut, scale=tf.ones_like(mut))
                # else:
                #     mut, sigmat = fc_net(tf.concat([t, conf], 1), nh * [h], [[1, None], [1, tf.nn.softplus]], 'py_tz', lamba=lamba, activation=activation)
                #     y = Normal(loc=mut, scale=sigmat)

                y = normal(conf, nh * [h], 1, decoder_sigma, decoder_switch, t, 'py')

            # CEVAE variational approximation
            # @@ ENCODER
            # q(t|x)
            if not enc_obs_t and not no_vae:
                logits_t = fc_net(x_ph, [nh_t], [[1, None]], 'qt', lamba=lamba, activation=activation)
                qt = Bernoulli(logits=logits_t, dtype=tf.float32)
            else:
                qt = t_ph

            if reweighting:
                imp_weight = 0.5 * (qt / p_treat + (1. - qt) / (1. - p_treat))

            sigma_weight = 1. / safe_sqrt(imp_weight) if imp_weighting else 1.
            eval_weight = 1. / sigma_weight if eval_weighting else 1.

            # q(y|x,t)
            # if first_switch:
            #     hqy = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared', lamba=lamba, activation=activation)
            #
            #     if not qy_sigma:
            #
            #         mu_qy_t0 = fc_net(hqy, [h], [[1, None]], 'qy_xt0', lamba=lamba, activation=activation)
            #         mu_qy_t1 = fc_net(hqy, [h], [[1, None]], 'qy_xt1', lamba=lamba, activation=activation)
            #
            #         qy = Normal(loc=qt * mu_qy_t1 + (1. - qt) * mu_qy_t0,
            #                     scale=sigma_weight*tf.ones_like(mu_qy_t0))
            #     else:
            #         mu_qy_t0, sigma_qy_t0 = fc_net(hqy, [h], [[1, None], [1, tf.nn.softplus]], 'qy_xt0', lamba=lamba, activation=activation)
            #         mu_qy_t1, sigma_qy_t1 = fc_net(hqy, [h], [[1, None], [1, tf.nn.softplus]], 'qy_xt1', lamba=lamba, activation=activation)
            #         qy = Normal(loc=qt * mu_qy_t1 + (1. - qt) * mu_qy_t0,
            #                     scale=sigma_weight*(qt * sigma_qy_t1 + (1. - qt) * sigma_qy_t0))
            # else:
            #     mu_qy = fc_net(x_ph, nh * [h], [[1, None]], 'qy_xt1', lamba=lamba, activation=activation)
            #     qy = Normal(loc=mu_qy, scale=tf.ones_like(mu_qy))
            if not qz_matching:
                if first_repre:
                    hqy = fc_net(x_ph, (nh - 1) * [h], [], 'qy_xt_shared', lamba=lamba, activation=activation)
                else:
                    hqy = x_ph

                qy = normal(hqy, [h] if first_repre else nh*[h], 1, qy_sigma, first_switch, qt, 'qy', reweight=sigma_weight)

                # q(z|x,t,y)
                inpt2 = tf.concat([x_ph, qy], 1) if enc_x else qy
                if second_repre:
                    hqz = fc_net(inpt2, (nh - 1) * [h], [[nh_qz, None]], 'qz_xty_shared', lamba=lamba, activation=activation)
                else:
                    hqz = inpt2

                if enc_t:
                    qz = normal(hqz, [h] if second_repre else nh*[h], d, qz_sigma, second_switch, qt, 'qz_x')
                else:
                    mu_qz, sigma_qz = mu_sigma(hqz, [h] if second_repre else nh*[h], d, qz_sigma, False, 'qz')
                    qz = Normal(loc=mu_qz, scale=sigma_qz)
            else:
                if second_repre:
                    hqz = fc_net(x_ph, (nh - 1) * [h], [[nh_qz, None]], 'qz_xty_shared', lamba=lamba, activation=activation)
                else:
                    hqz = x_ph

                mu_qz, sigma_qz = mu_sigma(hqz, [h] if second_repre else nh * [h], d, qz_sigma, False, 'qz')
                qz = Normal(loc=mu_qz, scale=sigma_qz)

            # if second_switch:
            #
            #     muq_t0, sigmaq_t0 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt0', lamba=lamba,
            #                                activation=activation)
            #     muq_t1, sigmaq_t1 = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1', lamba=lamba,
            #                                activation=activation)
            #     qz = Normal(loc=qt * muq_t1 + (1. - qt) * muq_t0, scale=qt * sigmaq_t1 + (1. - qt) * sigmaq_t0)
            # else:
            #     # gxy = fc_net(inpt2, (nh - 1) * [h], [[1, None]], 'gxy', lamba=lamba, activation=activation)
            #     # hqz = fc_net(tf.concat([gxy, qt], 1), (nh - 1) * [h], [[1, None]], 'hqz', lamba=lamba, activation=activation)
            #     #
            #     # muq, sigmaq = fc_net(hqz, [h], [[d, None], [d, tf.nn.softplus]], 'muq_sigmaq', lamba=lamba,
            #     #                            activation=activation)
            #     # qz = Normal(loc=muq, scale=sigmaq)
            #
            #     muq_t1, sigmaq_t1 = fc_net(tf.concat([hqz, qt], 1), [h], [[d, None], [d, tf.nn.softplus]], 'qz_xt1', lamba=lamba,
            #                                activation=activation)
            #     qz = Normal(loc=muq_t1, scale=sigmaq_t1)


            # Create data dictionary for edward
            # data = {x1: x_ph_bin, x2: x_ph_cont, y: y_ph, qt: t_ph, t: t_ph, qy: y_ph}
            # data = {x1: x_ph_bin, x2: x_ph_cont, y: y_ph, t: t_ph}
            data = {}
            if infer_qy:
                data[qy] = y_ph
            if infer_y:
                data[y] = y_ph

            if not no_vae:
                if dec_t and not dec_obs_t:
                    data[t] = t_ph
                if dec_x:
                    data[x1] = x_ph_bin
                    data[x2] = x_ph_cont

                if enc_t and not enc_obs_t:
                    data[qt] = t_ph

            if not reweighting:
                # sample posterior predictive for p(y|z,t)
                if not eval_qy and not no_vae:
                    y_ite = ed.copy(y, {z: qz, t: t_ph}, scope='y_post')
                else:
                    y_ite = qy
            else:
                if eval_weighting:
                    y_ite = Normal(loc=qy.loc, scale=eval_weight*qy.scale)
                elif imp_weighting:
                    y_ite = Normal(loc=qy.loc, scale=qy.scale / sigma_weight)

            if dec_ate:
                y_ate = ed.copy(y, {z: qz if dec_ate_post else z, t: t_ph}, scope='y_ate')

            # crude approximation of the above
            # y_post_mean = ed.copy(y, {z: qz.mean(), t: t_ph}, scope='y_post_mean')

            # All inference should already be handled automatically inside the inference object
            if not no_vae:
                inference = ed.KLqp({z: qz}, data)
            else:
                inference = ed.KLqp(data=data)
            optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
            inference.initialize(optimizer=optimizer)

            loss_op = inference.loss

            if ipm_matching:
                h_rep_norm = hqy / safe_sqrt(tf.reduce_sum(tf.square(hqy), axis=1, keep_dims=True))
                imb_distance, _ = wasserstein(h_rep_norm, qt, 0.5, lam=10.0, its=10, sq=False,
                            backpropT=True)

                loss_op += enc_alpha * imb_distance

            if dec_matching:
                dec_repre_post = ed.copy(dec_repre, {z: qz}) if dec_ate_post else dec_repre
                dec_rep_norm = dec_repre_post / safe_sqrt(tf.reduce_sum(tf.square(dec_repre_post), axis=1, keep_dims=True))
                imb_distance_dec, _ = wasserstein(dec_rep_norm, qt if dec_ate_post else t, 0.5, lam=10.0, its=10, sq=False,
                            backpropT=True)

                loss_op += dec_alpha * imb_distance_dec

            if qz_matching:
                qz_norm = qz / safe_sqrt(tf.reduce_sum(tf.square(qz), axis=1, keep_dims=True))
                imb_distance_qz, _ = wasserstein(qz_norm, qt, 0.5, lam=10.0, its=10, sq=False,
                                              backpropT=True)

                loss_op += qz_alpha * imb_distance_qz

            train_op = optimizer.minimize(loss_op)

            saver = tf.train.Saver(tf.contrib.slim.get_variables())
            tf.global_variables_initializer().run()

            if not balanced_sample:
                n_epoch, n_iter_per_epoch, idx = epochs, 10 * int(xtr.shape[0] / 100), np.arange(xtr.shape[0])
            else:
                n_epoch, _, idx = epochs, 10 * int(xtr.shape[0] / 100), np.arange(xtr.shape[0])

                from imblearn.tensorflow import balanced_batch_generator
                from imblearn.over_sampling import RandomOverSampler
                training_generator, n_iter_per_epoch = balanced_batch_generator(
                    np.hstack([xtr, ytr]), ttr, sampler=RandomOverSampler(), batch_size=100)
                n_iter_per_epoch *= 5

            # dictionaries needed for evaluation
            # tr0, tr1 = np.zeros((xalltr.shape[0], 1)), np.ones((xalltr.shape[0], 1))
            # f1 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr1}
            # f0 = {x_ph_bin: xalltr[:, 0:len(binfeats)], x_ph_cont: xalltr[:, len(binfeats):], t_ph: tr0}

            t0te, t1te = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
            f1te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t1te}
            f0te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t0te}
            t0va, t1va = np.zeros((xva.shape[0], 1)), np.ones((xva.shape[0], 1))
            f1va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t1va}
            f0va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t0va}
            t0tr, t1tr = np.zeros((xtr.shape[0], 1)), np.ones((xtr.shape[0], 1))
            f1tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t1tr}
            f0tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t0tr}

            for epoch in range(n_epoch):
                avg_loss = 0.0

                t0 = time.time()
                widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
                pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
                pbar.start()
                np.random.shuffle(idx)
                for j in range(n_iter_per_epoch):
                    pbar.update(j)

                    if not balanced_sample:
                        batch = np.random.choice(idx, 100)
                        x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch]
                    else:
                        data_train, t_train = next(training_generator)
                        x_train = data_train[:, :-1]
                        y_train = data_train[:, -1:]

                    # if ipm_matching:
                    _, loss = sess.run([train_op, loss_op], feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                                                            x_ph_cont: x_train[:, len(binfeats):],
                                                            t_ph: t_train, y_ph: y_train})
                    avg_loss += loss
                    # else:
                    #     info_dict = inference.update(feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                    #                                             x_ph_cont: x_train[:, len(binfeats):],
                    #                                             t_ph: t_train, y_ph: y_train})
                    #     avg_loss += info_dict['loss']

                avg_loss = avg_loss / n_iter_per_epoch
                avg_loss = avg_loss / 100

                if epoch % args.print_every == 0 and not validate=='pehe':
                    y0, y1 = get_y0_y1(sess, y_ite, f0tr, f1tr, shape=ytr.shape, L=1)
                    y0, y1 = y0 * ys + ym, y1 * ys + ym
                    score_train = evaluator_train.calc_stats(y1, y0)
                    rmses_train = evaluator_train.y_errors(y0, y1)

                    y0, y1 = get_y0_y1(sess, y_ite, f0te, f1te, shape=yte.shape, L=1)
                    y0, y1 = y0 * ys + ym, y1 * ys + ym
                    score_test = evaluator_test.calc_stats(y1, y0)

                    y0, y1 = get_y0_y1(sess, y_ite, f0va, f1va, shape=yva.shape, L=1)
                    y0, y1 = y0 * ys + ym, y1 * ys + ym
                    score_val = evaluator_validation.calc_stats(y1, y0)

                    print("Epoch: {}/{}, log p(x) >= {:0.3f}, \n"
                          "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                          # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                          "ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, \n"
                          "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                          "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                               score_train[0], score_train[1], score_train[2],
                                               # rmses_train[0], rmses_train[1],
                                               score_test[0], score_test[1], score_test[2],
                                               score_val[0], score_val[1], score_val[2],
                                               time.time() - t0))

                if epoch % validate_step == 0 or epoch == (n_epoch - 1):
                    if validate == 'elbo':
                        # construct a deterministic version (i.e. use the mean of the approximate posterior) of the lower bound
                        # for early stopping according to a validation set
                        if not enc_obs_t:
                            y_post_eval = ed.copy(y, {z: qz.mean(), qt: t_ph, qy: y_ph, t: t_ph}, scope='y_post_eval')
                            if dec_x:
                                x1_post_eval = ed.copy(x1, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='x1_post_eval')
                                x2_post_eval = ed.copy(x2, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='x2_post_eval')
                            t_post_eval = ed.copy(t, {z: qz.mean(), qt: t_ph, qy: y_ph}, scope='t_post_eval')
                        else:
                            y_post_eval = ed.copy(y, {z: qz.mean(), qy: y_ph, t: t_ph}, scope='y_post_eval')
                            if dec_x:
                                x1_post_eval = ed.copy(x1, {z: qz.mean(), qy: y_ph}, scope='x1_post_eval')
                                x2_post_eval = ed.copy(x2, {z: qz.mean(), qy: y_ph}, scope='x2_post_eval')
                            t_post_eval = ed.copy(t, {z: qz.mean(), qy: y_ph}, scope='t_post_eval')

                        logp_valid = tf.reduce_mean(
                            tf.reduce_sum(y_post_eval.log_prob(y_ph) + t_post_eval.log_prob(t_ph), axis=1) +
                            tf.reduce_sum(x1_post_eval.log_prob(x_ph_bin), axis=1) if dec_x else 0 +
                            tf.reduce_sum(x2_post_eval.log_prob(x_ph_cont), axis=1) if dec_x else 0 +
                            tf.reduce_sum(z.log_prob(qz.mean()) - qz.log_prob(qz.mean()), axis=1)
                        )

                        logpvalid = sess.run(logp_valid, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                                                                    x_ph_cont: xva[:, len(binfeats):],
                                                                    t_ph: tva, y_ph: yva})
                        if logpvalid >= best_logpvalid:
                            print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid,
                                                                                                 logpvalid))
                            best_logpvalid = logpvalid

                            saver.save(sess, train_dir + 'model')
                    elif validate == 'pehe':
                        y0, y1 = get_y0_y1(sess, y_ite, f0tr, f1tr, shape=ytr.shape, L=100 if not reweighting else 1)
                        y0, y1 = y0 * ys + ym, y1 * ys + ym
                        score_train = evaluator_train.calc_stats(y1, y0)
                        rmses_train = evaluator_train.y_errors(y0, y1)

                        y0, y1 = get_y0_y1(sess, y_ite, f0te, f1te, shape=yte.shape, L=100 if not reweighting else 1)
                        y0, y1 = y0 * ys + ym, y1 * ys + ym
                        score_test = evaluator_test.calc_stats(y1, y0)

                        y0, y1 = get_y0_y1(sess, y_ite, f0va, f1va, shape=yva.shape, L=100 if not reweighting else 1)
                        y0, y1 = y0 * ys + ym, y1 * ys + ym
                        score_val = evaluator_validation.calc_stats(y1, y0)

                        if dec_ate:
                            if dec_ate_post:
                                y0, y1 = get_y0_y1(sess, y_ate, f0va, f1va, shape=yva.shape, L=100)
                            else:
                                y0, y1 = get_y0_y1(sess, y_ate, {t_ph: [[0]]}, {t_ph: [[1]]}, shape=(1, 1), L=100)
                            y0, y1 = y0 * ys + ym, y1 * ys + ym
                            ate_va = evaluator_validation.abs_ate(y1, y0)

                            if dec_ate_post:
                                y0, y1 = get_y0_y1(sess, y_ate, f0te, f1te, shape=yte.shape, L=100)
                            else:
                                y0, y1 = get_y0_y1(sess, y_ate, {t_ph: [[0]]}, {t_ph: [[1]]}, shape=(1, 1), L=100)
                            y0, y1 = y0 * ys + ym, y1 * ys + ym
                            ate_test = evaluator_test.abs_ate(y1, y0)

                            print("Epoch: {}/{}, log p(x) >= {:0.3f}, \n"
                                  "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                                  # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                                  "ite_te: {:0.3f}, ate_te: {:0.3f}/{:0.3f}, pehe_te: {:0.3f}, \n"
                                  "ite_va: {:0.3f}, ate_va: {:0.3f}/{:0.3f}, pehe_va: {:0.3f}, \n"
                                  "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                       score_train[0], score_train[1], score_train[2],
                                                       # rmses_train[0], rmses_train[1],
                                                       score_test[0], score_test[1], ate_test, score_test[2],
                                                       score_val[0], score_val[1], ate_va, score_val[2],
                                                       time.time() - t0))
                        else:
                            print("Epoch: {}/{}, log p(x) >= {:0.3f}, \n"
                                  "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                                  # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                                  "ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, \n"
                                  "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                                  "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                       score_train[0], score_train[1], score_train[2],
                                                       # rmses_train[0], rmses_train[1],
                                                       score_test[0], score_test[1], score_test[2],
                                                       score_val[0], score_val[1], score_val[2],
                                                       time.time() - t0))

                        if dec_ate:
                            if ate_va < best_atevalid:
                                print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_atevalid,
                                                                                                     ate_va))
                                best_atevalid = ate_va
                                saver.save(sess, train_dir + 'model')
                        else:
                            if score_val[2] < best_pehevalid:

                                print('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_pehevalid,
                                                                                                     score_val[2]))
                                best_pehevalid = score_val[2]
                                saver.save(sess, train_dir + 'model')


            saver.restore(sess, train_dir + 'model')

            y0, y1 = get_y0_y1(sess, y_ite, f0tr, f1tr, shape=ytr.shape, L=100 if not reweighting else 1)
            y0, y1 = y0 * ys + ym, y1 * ys + ym
            score = evaluator_train.calc_stats(y1, y0)
            scores[i, :] = score

            if dec_ate:
                if dec_ate_post:
                    y0, y1 = get_y0_y1(sess, y_ate, f0tr, f1tr, shape=ytr.shape, L=100)
                else:
                    y0, y1 = get_y0_y1(sess, y_ate, {t_ph: [[0]]}, {t_ph: [[1]]}, shape=(1, 1), L=100)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                ate = evaluator_train.abs_ate(y1, y0)
                scores[i][1] = ate

            y0t, y1t = get_y0_y1(sess, y_ite, f0te, f1te, shape=yte.shape, L=100 if not reweighting else 1)
            y0t, y1t = y0t * ys + ym, y1t * ys + ym
            score_test = evaluator_test.calc_stats(y1t, y0t)
            scores_test[i, :] = score_test

            if dec_ate:
                if dec_ate_post:
                    y0, y1 = get_y0_y1(sess, y_ate, f0te, f1te, shape=yte.shape, L=100)
                else:
                    y0, y1 = get_y0_y1(sess, y_ate, {t_ph: [[0]]}, {t_ph: [[1]]}, shape=(1, 1), L=100)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                ate_test = evaluator_test.abs_ate(y1, y0)
                scores_test[i][1] = ate_test

            if not dec_ate:
                print('Replication: {}/{}, tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}' \
                      ', te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}'.format(i + 1, args.reps,
                                                                                    score[0], score[1], score[2],
                                                                                    score_test[0], score_test[1], score_test[2]))
            else:
                print('Replication: {}/{}'
                      ', tr_ite: {:0.3f}, tr_ate: {:0.3f}/{:0.3f}, tr_pehe: {:0.3f}' \
                      ', te_ite: {:0.3f}, te_ate: {:0.3f}/{:0.3f}, te_pehe: {:0.3f}'.format(i + 1,
                                                                                            args.reps,
                                                                                            score[0], score[1], ate, score[2],
                                                                                            score_test[0],
                                                                                            score_test[1],
                                                                                            ate_test,
                                                                                            score_test[2]
                                                                                            ))
            sess.close()

    print('CEVAE model total scores')
    means, stds = np.mean(scores, axis=0), sem(scores, axis=0)
    printf('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    means, stds = np.mean(scores_test, axis=0), sem(scores_test, axis=0)
    printf('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    printf(all_args)
    printf(non_defaults)
    return scores, scores_test


# Script entry point: launch the experiment once with this file's chosen
# configuration. Every keyword not listed in the mapping below keeps the
# default declared in run()'s signature.
#
# Overrides that were tried earlier and are currently disabled, kept here for
# reference (re-enable by adding the entry to the mapping):
#   d=200, nh_t=2, equiv_confounder=True, equiv_conf_covar=True,
#   direct_conf_covar=True, dim_conf=2, decoder_switch=False,
#   decoder_sigma=True, qy_sigma=True, second_switch=False,
#   second_repre=False, qz_sigma=False, balanced_sample=True,
#   imp_weighting=True, eval_weighting=True, ipm_matching=True,
#   enc_alpha=100, dec_matching=True, dec_alpha=10000, qz_matching=True,
#   qz_alpha=0, dec_obs_t=True, dec_x=False, enc_x=False, enc_t=False,
#   dec_t=False, infer_qy=False, dec_ate=True, dec_ate_post=True,
#   validate_step=5, epochs=1, index=12
#   (num_z_conf=10 was already disabled one level deeper in the original list)
run(**{
    'first_switch': False,
    'first_repre': False,
    'enc_obs_t': True,
    'no_vae': True,
    'infer_y': False,
    'eval_qy': True,
    'validate': 'pehe',
})
