#!/usr/bin/env python

import os
import edward as ed
import tensorflow as tf
# from sacred import Experiment
from copy import deepcopy
import time
import inspect

# ex = Experiment('cevae mods')

# from tensorflow.contrib.distributions import Normal
from edward.models import *
from progressbar import ETA, Bar, Percentage, ProgressBar
import matplotlib.pyplot as plt

from functools import partial

from datasets import *
from evaluation import Evaluator
import numpy as np
import time
from scipy.stats import sem

from argparse import ArgumentParser

from gcn.layers import GraphConvolution
from gcn.utils import preprocess_features, preprocess_adj

import pickle

# Command-line options for the experiment script.  Options are parsed once at
# import time and exposed through the module-level `args` namespace.
parser = ArgumentParser()
# Number of replications; in the visible code it is only used for the
# "Replication i/N" progress message.
parser.add_argument('-reps', type=int, default=100)
# parser.add_argument('-earl', type=int, default=10)
# NOTE(review): `-lr` and `-opt` are parsed but not read in the visible part of
# this file (the training code there builds Adam with a literal learning
# rate) -- confirm whether they are still meant to have an effect.
parser.add_argument('-lr', type=float, default=0.001)
parser.add_argument('-opt', choices=['adam', 'adamax'], default='adam')
# parser.add_argument('-epochs', type=int, default=100)
# Evaluate/plot every `print_every` epochs (networked runs evaluate every epoch).
parser.add_argument('-print_every', type=int, default=10)
args = parser.parse_args()

# Extra flag attached by hand rather than through the CLI; it is not read in
# the visible part of this file.
args.true_post = True

# data_gen = IHDP(replications=args.reps)

# Switch between the plain setting and the networked (graph-convolution)
# setting by editing these lines; the second assignment is the one in effect.
is_networked = False
is_networked = True
# Data source for the run.  `Pokec` is presumably provided by
# `from datasets import *` -- confirm.
data_gen = Pokec(replications=1, confounder='age')

if is_networked:
    # The `if False:` branch is disabled; it records how `adj.pickle` is built.
    # NOTE(review): `pokec` is not defined in the visible part of this file, so
    # re-enabling this branch needs that name to be in scope -- confirm.
    if False:
        pokec_network = pokec.network_data()

        adjacency_list = pokec_network.adjacency_list
        neighbours_list = adjacency_list.neighbours
        offsets_list = adjacency_list.offsets

        import collections

        # Neighbours of vertex v are the slice
        # neighbours_list[offsets_list[v]:offsets_list[v + 1]].
        adjacency_dict = collections.defaultdict(list)
        for vertex in range(len(offsets_list)-1):
            adjacency_dict[vertex] = neighbours_list[offsets_list[vertex]:offsets_list[vertex+1]]

        # The last vertex has no following offset, so it takes the tail.
        adjacency_dict[len(offsets_list)-1] = neighbours_list[offsets_list[-1]:]

        import networkx as nx

        adj = nx.adjacency_matrix(nx.from_dict_of_lists(adjacency_dict))

        # Run the gcn package's adjacency preprocessing before caching.
        adj = preprocess_adj(adj)

        # Cache the preprocessed adjacency so later runs can skip the rebuild.
        with open('adj.pickle', 'wb') as handle:
            pickle.dump(adj, handle, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        # Load the cached adjacency from the working directory.
        # NOTE: unpickling can execute arbitrary code; only load an
        # `adj.pickle` that was produced locally by the branch above.
        with open('adj.pickle', 'rb') as handle:
            adj = pickle.load(handle)


# exit(0)
# ATE = []
# for (train, valid, test, contfeats, binfeats) in data_gen.get_train_valid_test():
#     if len(train) == 3:
#         (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr = train
#         (xva, tva, yva), (y_cfva, mu0va, mu1va), zva = valid
#         (xte, tte, yte), (y_cfte, mu0te, mu1te), zte = test
#     else:
#         (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
#         (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
#         (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
#
#     evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
#     evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
#     evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)
#     ATE.append(evaluator_train.true_ate)
#
#
# plt.hist(ATE, 50)
# plt.savefig('hist_.png')
#
# exit(0)

# for i in range(100):
#     artificial.model(i)
# quit()

# Module-level hyper-parameters shared by `run`.
# NOTE(review): `dimx`, `lamba` and `nh` are not read in the visible part of
# this file (only `h`, `M` and `gcn_dim` are) -- confirm they are still needed.
dimx = 25

# Leading (batch) dimension of the input placeholders; None leaves it unspecified.
M = None  # batch size during training

lamba = 1e-4  # weight decay
nh, h = 3, 200  # number and size of hidden layers
gcn_dim = 50  # output width of each GraphConvolution layer

# `config_nnbuilder` is called before the star import below; the commented-out
# call shows an alternative configuration with an explicit initializer.
from utils import config_nnbuilder
# config_nnbuilder(initializer=partial(initializers.xavier_initializer, uniform=False))
config_nnbuilder()
from utils import *

# @ex.config
# def cfg():
#     equiv_confounder = True
#     equiv_conf_covar = True
#     # direct_conf_covar = False
#     # num_z_conf = 10
#     decoder_switch = False
#     decoder_sigma = True
#     # second_switch = False
#     # balanced_sample = True
#     ipm_matching = True


# @ex.automain
def run(train_dir='./',#'/Users/abel/Downloads/CEVAE/models/ihdp/',
        use_gpu=False,
        proxy_noise_level=.2,
        gen_t=True,
        enc_t=True,
        trans_t=lambda x: x,
        cond_x=False,
        normed_z=False,
        suff_z=False,
        zt_num=1,
        zy_num=1,
        zt_y=False,
        zy_t=False,
        z_var_sigma=True,
        z_cond_t=False,
        metrics_sampled=False,
        y_sampled=False,
        nh_rep=3,
        nh_reg=3,
        nh_enc=3,
        nh_t=3,
        try_models=False,
        validate_models='elbo',
        weight_sigma=1,
        var_sigma=True,
        dec_switch=False,
        enc_switch=False,
        matching=None,
        alpha=100,
        beta=1,
        n_sample_vld=100,
        n_sample_eval=100,
        validate='pehe',
        va_rate=0.0,
        validate_step=10,
        stop_after=np.inf,
        epochs=100,
        n_repeat_sample=10,
        index=None,
        rand_seed=1,
        ):

    # [for k, v in locals().items()]
    all_args = locals()
    print(all_args)

    global ARGSPEC
    ARGSPEC = inspect.getfullargspec(run)
    non_defaults = {arg: all_args[arg] for arg, default
                    in zip(ARGSPEC.args, ARGSPEC.defaults)
                    if default != all_args[arg]
                    and arg != 'trans_t'
                    and not isinstance(all_args[arg], (tf.Tensor, list))}

    # TODO: backup locals()

    train_dir += time.strftime("%Y-%m-%d %H:%M:%S ") + str(non_defaults).replace(':', '-') + '/'
    if not os.path.exists(train_dir):
        print("make savefolder: {0:s}...".format(train_dir))
        os.makedirs(train_dir)  # Make folder

    def printf(*values, filename=train_dir + 'output.txt'):
        """Print ``values`` to stdout and append the same line to ``filename``.

        ``filename`` defaults to ``output.txt`` inside this run's output
        folder (the default is evaluated once, when ``printf`` is defined).
        The file is opened in append mode on every call.
        """
        print(*values)
        # Same text `print` would produce: items joined by one space + newline.
        line = ' '.join(map(str, values)) + '\n'
        with open(filename, 'a') as log_file:
            log_file.write(line)

    scores = np.empty((0, 3))
    scores_test = np.empty((0, 3))

    art_config = yes_backdoor._replace(proxy_noise_level=proxy_noise_level)

    # data_gen = LinearArtificial(replications=50, samples=500, config=art_config)
    # data_gen = LinearArtificial(replications=5, samples=500, config=yes_backdoor)

    for i, (train, valid, test, contfeats, binfeats) in enumerate(data_gen.get_train_valid_test(index)):
        # continue
        i = index or i
        print('\nReplication {}/{}'.format(i + 1, args.reps))
        if len(train) == 3:
            (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr = train
            (xva, tva, yva), (y_cfva, mu0va, mu1va), zva = valid
            (xte, tte, yte), (y_cfte, mu0te, mu1te), zte = test
        else:
            (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
            (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
            (xte, tte, yte), (y_cfte, mu0te, mu1te) = test

        evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
        evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
        evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)
        printf('true ATE tr', evaluator_train.true_ate)
        printf('true ATE va', evaluator_validation.true_ate)
        printf('true ATE te', evaluator_test.true_ate)

        # reorder features with binary first and continuous after
        perm = binfeats + contfeats
        xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

        # xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
        # evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
        #                             mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

        # zero mean, unit variance for y during training
        ym, ys = np.mean(ytr), np.std(ytr)
        ytr, yva, yte = (ytr - ym) / ys, (yva - ym) / ys, (yte - ym) / ys

        # normalize
        # from sklearn.preprocessing import StandardScaler
        xtr, xva, xte = tuple(StandardScaler().fit_transform(data) for data in (xtr, xva, xte))  # , ytr, yva, yte


        # symmetric t
        # ttr = ttr.astype(np.float32)
        # ttr -= 0.5
        # tva = tva.astype(np.float32)
        # tva -= 0.5
        # tte = tte.astype(np.float32)
        # tte -= 0.5

        best_model_result = None
        for zt_y, zy_t in ((False, False), (False, True), (True, False), (True, True)) if try_models else ((zt_y, zy_t),):

            with tf.Graph().as_default(), tf.device("/gpu:0" if use_gpu else "/cpu:0"):

                # if rand_seed is not None:
                #     ed.set_seed(rand_seed)
                #     np.random.seed(rand_seed)
                #     tf.set_random_seed(rand_seed)

                ed.set_seed(i)
                np.random.seed(i)
                tf.set_random_seed(i)

                sess = tf.InteractiveSession()

                x_ph_bin = tf.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
                x_ph_cont = tf.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
                t_ph = tf.placeholder(tf.float32, [M, 1], name='t_ph')
                y_ph = tf.placeholder(tf.float32, [M, 1], name='y_ph')

                x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
                activation = tf.nn.elu

                # conditional prior p(z|x)
                # can be seen as a representation learner
                if not z_cond_t:
                    if is_networked:

                        # features = preprocess_features(sdfasdf)
                        # support = [preprocess_adj(adj)]

                        placeholders = {
                            'support': [tf.sparse_placeholder(tf.float32)],
                            'features': tf.placeholder(tf.float32, (None, xtr.shape[1])),
                            # 'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
                            # 'labels_mask': tf.placeholder(tf.int32),
                            'dropout': tf.placeholder_with_default(0., shape=()),
                            'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
                        }

                        gcn_layers = [
                            GraphConvolution(input_dim=xtr.shape[1],
                                             output_dim=gcn_dim,  # might need larger
                                             placeholders=placeholders,
                                             act=tf.nn.relu,
                                             dropout=True,
                                             sparse_inputs=False),
                            GraphConvolution(input_dim=gcn_dim,
                                             output_dim=gcn_dim,
                                             placeholders=placeholders,
                                             act=lambda x: x,
                                             dropout=True)
                        ]
                        repre = x_ph
                        for l in gcn_layers:
                            repre = l(repre)

                        mu_pz_x, sigma_pz_x = mu_sigma(repre, [200], zt_num + zy_num, var_sigma=z_var_sigma,
                                                       switch=False, name='pz_x')
                        z = Normal(loc=mu_pz_x, scale=sigma_pz_x, name='pz_networked')
                    else:
                        mu_pz_x, sigma_pz_x = mu_sigma(x_ph, nh_rep * [h], zt_num + zy_num, var_sigma=z_var_sigma, switch=False, name='pz_x')
                        z = Normal(loc=mu_pz_x, scale=sigma_pz_x, name='pz_x')
                else:
                    z = normal(x_ph, nh_rep * [h], zt_num + zy_num, var_sigma=z_var_sigma, switch=False, t=trans_t(t_ph), name='pz_xt')

                zt = z[:, :zt_num]
                zy = z[:, zt_num:]
                if suff_z:
                    zy_sq = tf.square(zy)
                    zy = tf.concat([zy, zy_sq], 1)

                # @@ DECODER
                def switch(inp, on):
                    """Gate a 2-D input by columns.

                    Returns ``inp`` with all of its columns when ``on`` is
                    truthy, and a zero-column slice of ``inp`` otherwise, so the
                    result can always be concatenated along axis 1.
                    """
                    n_cols = inp.shape[1] if on else 0
                    return inp[:, :n_cols]

                # p(t|z)
                # if gen_t:
                #     # TODO: adjust depth
                #     logits = fc_net(tf.concat([zt, switch(zy, zy_t), switch(x_ph, cond_x)], 1),
                #                     nh_t * [h], [[1, None]], 'pt_z')
                #     t = Bernoulli(logits=logits, dtype=tf.float32)
                # else:
                #     # eval indep
                #     logits = fc_net(x_ph, [h], [[1, None]], 'pt_z')


                # p(y|t,z)
                # mu_py, sigma_py = mu_sigma(tf.concat([t, z if z0_y else zy], 1), nh*[h], 1, var_sigma=var_sigma, switch=False, name='py')
                # y = Normal(loc=mu_py, scale=py_sigma*sigma_py)
                def build_y(weight_sigma):
                    """Build the outcome variable y given the latent code and treatment.

                    ``weight_sigma`` is forwarded to ``normal`` as ``reweight``;
                    the string ``'trainable'`` is replaced by a fresh
                    ``tf.Variable`` initialised to 0.1.

                    The network input is ``zy`` concatenated with ``zt`` (only
                    when ``zt_y``) and ``x_ph`` (only when ``cond_x``).

                    NOTE(review): with ``gen_t`` enabled this reads a name ``t``
                    that is not assigned in the visible part of this file (the
                    p(t|z) block above is commented out) -- confirm where it
                    comes from.
                    """
                    if weight_sigma == 'trainable':
                        reweight = tf.Variable(.1)
                    else:
                        reweight = weight_sigma

                    decoder_input = tf.concat(
                        [zy, switch(zt, zt_y), switch(x_ph, cond_x)], 1)
                    treatment = trans_t(t if gen_t else t_ph)

                    return normal(decoder_input, nh_reg * [h], 1, var_sigma, dec_switch,
                                  treatment,  # name='py',
                                  reweight=reweight)

                if isinstance(weight_sigma, list):
                    y = [build_y(w) for w in weight_sigma]
                else:
                    y = build_y(weight_sigma)


                # @@ ENCODER
                # mu_qz, sigma_qz = mu_sigma(tf.concat([t_ph, y_ph, x_ph], 1), nh * [h], z0_num + z1_num, var_sigma=True, switch=False, name='qz')
                # qz = Normal(loc=mu_qz, scale=sigma_qz)

                def normal_no_t(inp, name):
                    """Build a Normal over the latent code from ``inp`` alone (no treatment input).

                    ``mu_sigma`` is called with ``nh_enc`` hidden layers of size
                    ``h`` and ``zt_num + zy_num`` outputs; its two results become
                    the ``loc`` and ``scale`` of the returned ``Normal``.
                    """
                    hidden_sizes = nh_enc * [h]
                    latent_dim = zt_num + zy_num
                    loc, scale = mu_sigma(inp, hidden_sizes, latent_dim,
                                          var_sigma=True, switch=False, name=name)
                    return Normal(loc=loc, scale=scale)

                def sel_by(inp, sel, value):
                    """Return the rows of ``inp`` whose entry in ``sel`` equals ``value``.

                    ``sel`` is expected to hold one entry per row of ``inp``
                    (for example a ``[rows, 1]`` treatment tensor); it is
                    flattened to a one-dimensional boolean row mask.

                    The previous implementation cast the boolean mask to int32
                    and passed it to ``tf.gather_nd`` as *indices*, which gathers
                    row 0 / row 1 of ``inp`` repeatedly instead of filtering rows.
                    """
                    row_mask = tf.reshape(tf.equal(sel, value), [-1])
                    return tf.boolean_mask(inp, row_mask)

                inp = tf.concat([y_ph, x_ph], 1)

                if enc_t:
                    qz = normal(inp, nh_enc * [h], zt_num + zy_num, True, enc_switch, trans_t(t_ph), name='qz')

                    # qz1 = normal_no_t(sel_by(inp, t_ph, 1.), 'qz1')
                    # qz0 = normal_no_t(sel_by(inp, t_ph, 0.), 'qz0')

                else:
                    qz = normal_no_t(inp, 'qz')

                is_train = tf.placeholder(tf.bool, name="is_train")

                # def normed_RV(qz):
                #     from tensorflow.python.layers.normalization import BatchNormalization
                #     bn = BatchNormalization(center=False, scale=False)
                #     bn.apply(qz, training=is_train)
                #
                #     mean, var = tf.cond(
                #         is_train,
                #         lambda: tf.nn.moments(qz, axes=[0]),
                #         lambda: (bn.moving_mean, bn.moving_variance)
                #     )
                #
                #     scale = tf.rsqrt(var + 1e-3)
                #     offset = -mean*scale
                #
                #     from tensorflow.contrib.distributions.python.ops.bijectors import Affine
                #     # from tensorflow.contrib.linalg.python.ops.linear_operator_diag import LinearOperatorDiag
                #
                #     import tensorflow.contrib.distributions as ds
                #
                #     QzDist = ds.TransformedDistribution(
                #         distribution=qz,
                #         bijector=Affine(shift=offset, scale_identity_multiplier=scale)
                #     )
                #
                #     from edward.models import RandomVariable
                #
                #     class QzRV(RandomVariable, ds.TransformedDistribution):
                #         def __init__(self, *args, **kwargs):
                #             super(QzRV, self).__init__(*args, **kwargs)
                #
                #         def _log_prob(self, value):
                #             return QzDist._log_prob(value)
                #
                #         def _sample_n(self, n, seed=None):
                #             return QzDist._sample_n(n, seed)
                #
                #         def _mean(self):
                #             return qz_orig.mean()*scale + offset
                #
                #     # QzRV = type('QzRV', (RandomVariable, QzDist), {})
                #     qz_orig = qz
                #     return QzRV(distribution=QzDist)
                #
                # if normed_z:
                #     qz = normed_RV(qz)


                # for plot
                mu_qz_predict = [None]
                mu_z_predict = [None]
                qz_predict = [None]
                z_predict = [None]

                t_assign = tf.placeholder(tf.float32, [M, 1], name='t_assign')
                # Workaround for a bug (or possibly a mistake on my side): sampling y directly
                #   shows a mysterious dependence on t_ph.
                # Instead, we first sample z or pz (or take their mean) and then plug that data into y;
                #   this should work in any case.
                z_ph = tf.placeholder(tf.float32, [M, 1], name='z_ph')

                # TODO: test on ihdp
                def train_proc(y, model_name='model'):

                    def predict_y(x, yf=None, t=None, use_mean=True, sample_n=1, in_sample=True):
                        """Predict both potential outcomes for every row of ``x``.

                        Args:
                            x: feature matrix with the binary columns first and
                                the continuous columns after them.
                            yf: factual outcomes fed to ``y_ph``.
                            t: factual treatments fed to ``t_ph``.
                            use_mean: plug in the mean of the latent distribution
                                instead of a sample (only read when ``gen_t`` is
                                false).
                            sample_n: number of prediction passes to average
                                (only read when ``gen_t`` is false).
                            in_sample: use the posterior ``qz`` when true;
                                otherwise ``y_ph``/``t_ph`` are removed from the
                                feed and the prior ``z`` is used.

                        Returns:
                            ``(y0, y1)``: predictions under treatment 0 and 1,
                            mapped back to the original outcome scale with
                            ``ys`` and ``ym``.

                        Side effects:
                            Stores the latest latent means/samples in
                            ``mu_qz_predict[0]``, ``qz_predict[0]``,
                            ``mu_z_predict[0]`` and ``z_predict[0]`` (used for
                            plotting).  Every call issues new ``ed.copy`` calls.
                        """
                        if gen_t:
                            # NOTE: only placeholders can be replaced by data through copy.
                            # NOTE(review): `t` below is this function's own
                            # parameter, not an outer treatment variable --
                            # confirm this branch works as intended.
                            muy = ed.copy(y.loc, {z: qz, t: t_assign}, scope='muy')
                        else:
                            # Sample instead of using the mean when experimenting on
                            # data with larger noise; each x must then be sampled
                            # several times and averaged.
                            # NOTE: placeholders can be replaced by data through copy,
                            # but there is a bug when the amounts of data differ.
                            replace_dict = {t_ph: t_assign, z: z_ph}
                            mu0p = ed.copy(y.loc, replace_dict, scope='mu0_post')
                            mu1p = ed.copy(y.loc, replace_dict, scope='mu1_post')

                        # Assigning t directly inside qz does not work, because the
                        # recovery relies on the information given by the factual t.

                        n_bin = len(binfeats)
                        fd = {
                            x_ph_bin: x[:, 0:n_bin], x_ph_cont: x[:, n_bin:],
                            y_ph: yf, t_ph: t,
                            is_train: False  # for batch norm
                        }

                        if is_networked:
                            # Feed the data being predicted, not the training batch
                            # left over from the enclosing training loop.
                            fd[placeholders['features']] = x
                            fd[placeholders['support'][0]] = adj
                            # dropout keeps its default of 0 when not training

                        if not enc_t:
                            del fd[t_ph]

                        mu_qz_predict[0] = sess.run(qz.mean(), feed_dict=fd)
                        qz_predict[0] = sess.run(qz, feed_dict=fd)

                        if not in_sample:
                            # t_ph may already have been removed above (enc_t
                            # false), so drop the keys without requiring them.
                            fd.pop(y_ph, None)
                            fd.pop(t_ph, None)

                        mu_z_predict[0] = sess.run(z.mean(), feed_dict=fd)
                        z_predict[0] = sess.run(z, feed_dict=fd)

                        if gen_t:
                            fd[t_assign] = np.zeros((t.shape[0], 1))
                            y0 = sess.run(muy, feed_dict=fd)
                            fd[t_assign] = np.ones((t.shape[0], 1))
                            y1 = sess.run(muy, feed_dict=fd)
                        else:
                            y0_sample, y1_sample = [], []
                            for _ in range(sample_n):
                                # NOTE: it is meaningless to set sample_n > 1 with use_mean
                                if use_mean:
                                    fd[z_ph] = mu_qz_predict[0] if in_sample else mu_z_predict[0]
                                else:
                                    fd[z_ph] = sess.run(qz, feed_dict=fd) if in_sample else sess.run(z, feed_dict=fd)

                                fd[t_assign] = np.zeros((x.shape[0], 1))
                                y0_sample.append(sess.run(mu0p, feed_dict=fd))
                                fd[t_assign] = np.ones((x.shape[0], 1))
                                y1_sample.append(sess.run(mu1p, feed_dict=fd))

                            # Average the passes column-wise into one prediction per row.
                            y0 = np.mean(np.hstack(y0_sample), axis=1, keepdims=True)
                            y1 = np.mean(np.hstack(y1_sample), axis=1, keepdims=True)

                        # Undo the zero-mean / unit-variance scaling applied to y.
                        y0, y1 = y0 * ys + ym, y1 * ys + ym

                        return y0, y1

                    def metrics(x, yf, t, use_mean=True, sample_n=1):
                        """Evaluate the objective terms on the given data.

                        Args:
                            x: feature matrix, binary columns first.
                            yf: factual outcomes fed to ``y_ph``.
                            t: factual treatments fed to ``t_ph``.
                            use_mean: evaluate at ``qz.mean()`` instead of at a
                                sample of ``qz``.
                            sample_n: number of session runs to average.

                        Returns:
                            ``(logp_v, KL_v)``: the averaged values of the two
                            tensors built by ``tensors`` below.  ``logp`` is the
                            batch mean of the summed log-likelihood plus
                            ``beta`` times ``log p(z) - log q(z)``; ``KL`` is the
                            batch mean of ``log p(z) - log q(z)`` alone.
                        """
                        fd = {x_ph_bin: x[:, 0:len(binfeats)],
                              x_ph_cont: x[:, len(binfeats):],
                              t_ph: t, y_ph: yf, is_train: False
                              }

                        if is_networked:
                            # Feed the data being evaluated; the previous code used
                            # `x_train`, which only exists once the training loop
                            # in the enclosing function has run.
                            fd[placeholders['features']] = x
                            fd[placeholders['support'][0]] = adj

                        def tensors(use_mean):
                            """Build the (logp, KL) tensors at the mean or a sample of qz."""
                            qz_or_mean = qz.mean() if use_mean else qz

                            # if we want the exact version, we can NOT just remove .mean(), need to sample qz
                            # NOTE(review): `t` here is this function's own
                            # parameter -- confirm the gen_t branch is intended.
                            y_post_eval = ed.copy(y, {z: qz_or_mean, t: t_ph} if gen_t else {z: qz_or_mean},
                                                  scope='y_post_eval')
                            if gen_t:
                                t_post_eval = ed.copy(t, {z: qz_or_mean}, scope='t_post_eval')

                            logp = tf.reduce_mean(
                                tf.reduce_sum((y_post_eval.log_prob(y_ph) + t_post_eval.log_prob(
                                    t_ph)) if gen_t else y_post_eval.log_prob(y_ph), axis=1) +
                                # TODO: test beta!
                                beta * tf.reduce_sum(z.log_prob(qz_or_mean) - qz.log_prob(qz_or_mean), axis=1)
                            )

                            KL = tf.reduce_mean(  # TODO: should be the same w or wo reduce_mean
                                tf.reduce_sum(z.log_prob(qz_or_mean) - qz.log_prob(qz_or_mean), axis=1)
                            )

                            return logp, KL

                        logp, KL = tensors(use_mean)

                        logp_v = np.mean([sess.run(logp, feed_dict=fd) for _ in range(sample_n)])

                        KL_v = np.mean([sess.run(KL, feed_dict=fd) for _ in range(sample_n)])

                        return logp_v, KL_v

                    # Create data dictionary for edward
                    data = {y: y_ph, t: t_ph} if gen_t else {y: y_ph}
                    # all of the inference should be handled automatically inside the inference object
                    #inference = ed.KLqp({z: qz}, data)
                    inference = ed.KLqp(data=data)
                    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
                    # optimizer = tf.train.MomentumOptimizer(learning_rate=0.0000001, momentum=.9, use_nesterov=True)
                    inference.initialize(optimizer=optimizer, kl_scaling={z: beta})

                    loss_op = inference.loss

                    if matching is not None:
                        matched = {
                            # TODO: support zt_num
                            'z': z,
                            # 'mu_z': mu_pz_x,
                        }[matching]
                        h_rep_norm = matched / safe_sqrt(tf.reduce_sum(tf.square(matched), axis=1, keep_dims=True))
                        imb_distance, _ = wasserstein(h_rep_norm, t_ph, 0.5, lam=10.0, its=10, sq=False,
                                                      backpropT=True)
                        p_alpha = (tf.Variable(1.) if alpha == 'trainable' else alpha)
                        loss_op += p_alpha * imb_distance

                    if is_networked:
                        for gcn_layer in gcn_layers:
                            for var in gcn_layer.vars.values():
                                loss_op += 1e-5 * tf.nn.l2_loss(var)

                    # for batch norm
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    with tf.control_dependencies(update_ops):
                        train_op = optimizer.minimize(loss_op)

                    saver = tf.train.Saver(tf.contrib.slim.get_variables())
                    tf.global_variables_initializer().run()

                    # save initial model
                    saver.save(sess, train_dir + model_name)

                    # t0te, t1te = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
                    # f1te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t1te}
                    # f0te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t0te}
                    # t0va, t1va = np.zeros((xva.shape[0], 1)), np.ones((xva.shape[0], 1))
                    # f1va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t1va}
                    # f0va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t0va}
                    # t0tr, t1tr = np.zeros((xtr.shape[0], 1)), np.ones((xtr.shape[0], 1))
                    # f1tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t1tr}
                    # f0tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t0tr}


                    # z_sample = [sess.run(z, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                    #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                    # z_sample = np.hstack(z_sample)
                    # t_sample = [sess.run(t, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                    #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                    # t_sample = np.hstack(t_sample)
                    #
                    # import dcor
                    # dcor_zt = np.mean([dcor.distance_correlation(z, t) for z, t in zip(z_sample, t_sample)])
                    # print('init dcor', dcor_zt)

                    count_for_stop = 0
                    n_epoch, n_iter_per_epoch, idx = epochs, n_repeat_sample*int(xtr.shape[0] / 100), np.arange(xtr.shape[0])
                    if is_networked:
                        n_epoch, n_iter_per_epoch = 100, 1

                    best_logpvalid = - np.inf
                    logpvalid = - np.inf
                    best_pehevalid = np.inf
                    best_atevalid = np.inf

                    plt.figure()

                    for epoch in range(n_epoch):
                        avg_loss = 0.0

                        t0 = time.time()
                        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
                        pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
                        pbar.start()
                        np.random.shuffle(idx)
                        for j in range(n_iter_per_epoch):
                            pbar.update(j)

                            # networked data uses a GCN, which trains on the full data (it cannot be trained in mini-batches)
                            batch = np.random.choice(idx, 100) if not is_networked else idx
                            x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch]

                            # if ipm_matching:

                            feed_dict = {x_ph_bin: x_train[:, 0:len(binfeats)],
                                                                    x_ph_cont: x_train[:, len(binfeats):],
                                                                    t_ph: t_train, y_ph: y_train, is_train: True}

                            if is_networked:
                                # feed_dict.update({placeholders['labels']: labels})
                                # feed_dict.update({placeholders['labels_mask']: labels_mask})
                                feed_dict.update({placeholders['features']: x_train})
                                feed_dict.update({placeholders['support'][0]: adj})
                                # GCN use a separate dropout rate
                                feed_dict.update({placeholders['dropout']: 0.1})

                            _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)
                            avg_loss += loss
                            # else:
                            #     info_dict = inference.update(feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                            #                                             x_ph_cont: x_train[:, len(binfeats):],
                            #                                             t_ph: t_train, y_ph: y_train})
                            #     avg_loss += info_dict['loss']

                        avg_loss = avg_loss / n_iter_per_epoch
                        avg_loss = avg_loss / (len(idx) if is_networked else 100)

                        if np.isnan(avg_loss):
                            break

                        if is_networked or epoch % args.print_every == 0:
                            # y0, y1 = get_y0_y1(sess, y_ite, f0tr, f1tr, shape=ytr.shape, L=n_sample_vld)
                            # TODO: on IHDP, Within-sample error is computed over both the training and validation sets
                            y0, y1 = predict_y(xtr, ytr, ttr, not y_sampled, n_sample_vld)
                            # plt.figure()
                            # plt.plot(mu_qz_predict[0][ttr == 0.], ztr[ttr == 0.], '.')
                            # plt.plot(mu_qz_predict[0][ttr == 1.], ztr[ttr == 1.], '.')
                            # plt.savefig(train_dir + 'zz%d_e%d_tr.png' % (i, epoch))
                            # plt.clf()

                            score_train = evaluator_train.calc_stats(y1, y0)

                            # rmses_train = evaluator_train.y_errors(y0, y1)

                            if not is_networked:
                                y0, y1 = predict_y(xva, yva, tva, not y_sampled, n_sample_vld)
                                score_val = evaluator_validation.calc_stats(y1, y0)

                                y0, y1 = predict_y(xte, yte, tte, not y_sampled, n_sample_vld, in_sample=False)
                                score_test = evaluator_test.calc_stats(y1, y0)

                            if len(train) == 3 or is_networked:
                                if is_networked:
                                    zte = data_gen.conf_data
                                    tte = ttr

                                plt.plot(y0[tte == 0.], mu_qz_predict[0][tte == 0.], '.')
                                plt.plot(y1[tte == 1.], mu_qz_predict[0][tte == 1.], '.')
                                plt.savefig(train_dir + '%d_e%d_yzf' % (i, epoch) + model_name + '_te.png')
                                plt.clf()
                                plt.plot(y0[tte == 1.], mu_qz_predict[0][tte == 1.], '.')
                                plt.plot(y1[tte == 0.], mu_qz_predict[0][tte == 0.], '.')
                                plt.savefig(train_dir + '%d_e%d_yzcf' % (i, epoch) + model_name + '_te.png')
                                plt.clf()

                                # plt.figure()
                                plt.plot(mu_qz_predict[0][tte == 0.], zte[tte == 0.], '.')
                                plt.plot(mu_qz_predict[0][tte == 1.], zte[tte == 1.], '.')
                                plt.savefig(train_dir + '%d_e%d_qtz' % (i, epoch) + model_name + '_te.png')
                                plt.clf()

                                # plt.figure()
                                plt.plot(mu_qz_predict[0][tte == 0.], qz_predict[0][tte == 0.], '.')
                                plt.plot(mu_qz_predict[0][tte == 1.], qz_predict[0][tte == 1.], '.')
                                plt.savefig(train_dir + '%d_e%d_qz' % (i, epoch) + model_name + '_te.png')
                                plt.clf()

                                # plt.figure()
                                # plt.plot(mu_z_predict[0][tte == 0.], zte[tte == 0.], '.')
                                # plt.plot(mu_z_predict[0][tte == 1.], zte[tte == 1.], '.')
                                # plt.savefig(train_dir + 'pz_tz_%d_e%d' % (i, epoch) + model_name + '_te.png')
                                # plt.clf()

                                # plt.figure()
                                plt.plot(mu_z_predict[0][tte == 0.], mu_qz_predict[0][tte == 0.], '.')
                                plt.plot(mu_z_predict[0][tte == 1.], mu_qz_predict[0][tte == 1.], '.')
                                plt.savefig(train_dir + '%d_e%d_pqz' % (i, epoch) + model_name + '_te.png')
                                plt.clf()

                                # plt.figure()
                                plt.plot(mu_z_predict[0][tte == 0.], z_predict[0][tte == 0.], '.')
                                plt.plot(mu_z_predict[0][tte == 1.], z_predict[0][tte == 1.], '.')
                                plt.savefig(train_dir + '%d_e%d_z' % (i, epoch) + model_name + '_te.png')
                                plt.clf()

                                # plt.figure()
                                # plt.plot(mu_z_predict[0][tte == 0.], xte[tte == 0.], '.')
                                # plt.plot(mu_z_predict[0][tte == 1.], xte[tte == 1.], '.')
                                # plt.savefig(train_dir + 'pz_tz_%d_e%d' % (i, epoch) + model_name + '_te.png')
                                # plt.clf()

                            # construct a deterministic version (i.e. use the mean of the approximate posterior) of the lower bound
                            # for early stopping according to a validation set

                            logptr, KLtr = metrics(xtr, ytr, ttr, not metrics_sampled, n_sample_vld)

                            fd = {x_ph_bin: xtr[:, 0:len(binfeats)],
                                                                        x_ph_cont: xtr[:, len(binfeats):],
                                                                        t_ph: ttr, y_ph: ytr, is_train: False}

                            if is_networked:
                                # feed_dict.update({placeholders['labels']: labels})
                                # feed_dict.update({placeholders['labels_mask']: labels_mask})
                                fd.update({placeholders['features']: x_train})
                                fd.update({placeholders['support'][0]: adj})

                            imbdtr = sess.run(imb_distance, feed_dict=fd)

                            if not is_networked:
                                logpvalid, KLvalid = metrics(xva, yva, tva, not metrics_sampled, n_sample_vld)

                                imbdvalid = sess.run(imb_distance, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                                                                              x_ph_cont: xva[:, len(binfeats):],
                                                                              t_ph: tva, y_ph: yva, is_train: False})

                            # TODO: high dim z
                            # z_sample = [sess.run(z, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                            #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                            # z_sample = np.hstack(z_sample)
                            # t_sample = [sess.run(t, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                            #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                            # t_sample = np.hstack(t_sample)
                            #
                            # import dcor
                            # dcor_zt = np.mean([dcor.distance_correlation(z, t) for z, t in zip(z_sample, t_sample)])
                            # print('dcor', dcor_zt)

                            if is_networked:
                                printf("Epoch: {}/{}, loss >= {:0.3f}, \n"
                                      "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                                      # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                                      "ELBO: tr {:0.3f}, va {:0.3f} \n"
                                       "KL: tr {:0.3f}, va {:0.3f} \n"
                                       "imbd: tr {:0.3f}, va {:0.3f} \n"
                                      "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                           score_train[0], score_train[1], score_train[2],
                                                           # rmses_train[0], rmses_train[1],
                                                           logptr, logpvalid,
                                                           KLtr, 0,
                                                           imbdtr, 0,
                                                       time.time() - t0))
                            else:
                                printf("Epoch: {}/{}, loss >= {:0.3f}, \n"
                                      "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                                      # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                                      "ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, \n"
                                      "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                                      "ELBO: tr {:0.3f}, va {:0.3f} \n"
                                       "KL: tr {:0.3f}, va {:0.3f} \n"
                                       "imbd: tr {:0.3f}, va {:0.3f} \n"
                                      "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                           score_train[0], score_train[1], score_train[2],
                                                           # rmses_train[0], rmses_train[1],
                                                           score_test[0], score_test[1], score_test[2],
                                                           score_val[0], score_val[1], score_val[2],
                                                           logptr, logpvalid,
                                                           KLtr, KLvalid,
                                                           imbdtr, imbdvalid,
                                                       time.time() - t0))

                            if weight_sigma == 'trainable':
                                printf('weight_sigma: %f' % sess.run(reweight))
                            # printf(sess.run(reweight))
                            if alpha == 'trainable':
                                printf('alpha: %f' % sess.run(p_alpha))

                            save_model = False
                            if best_logpvalid == -np.inf or \
                                    (logpvalid > best_logpvalid and (logpvalid-best_logpvalid)/np.abs(best_logpvalid) > va_rate):
                                printf('Improved validation elbo, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid,
                                                                                                      logpvalid))
                                best_logpvalid = logpvalid
                                if validate == 'elbo':
                                    save_model = True

                            if not is_networked and score_val[2]/best_pehevalid < 1 - va_rate:
                                printf('Improved validation pehe, old: {:0.3f}, new: {:0.3f}'.format(best_pehevalid,
                                                                                                      score_val[2]))
                                best_pehevalid = score_val[2]
                                if validate == 'pehe':
                                    save_model = True

                        if epoch % validate_step == 0 or epoch == (n_epoch - 1):
                            if save_model:
                                saver.save(sess, train_dir + model_name)
                                count_for_stop = 0
                            else:
                                count_for_stop += 1

                            if count_for_stop == stop_after:
                                break

                    if validate is not None:
                        saver.restore(sess, train_dir + model_name)

                    y0, y1 = predict_y(xtr, ytr, ttr, not y_sampled, n_sample_eval)
                    score_train = evaluator_train.calc_stats(y1, y0)

                    y0, y1 = predict_y(xte, yte, tte, not y_sampled, n_sample_eval, in_sample=False)
                    score_test = evaluator_test.calc_stats(y1, y0)

                    y0, y1 = predict_y(xva, yva, tva, not y_sampled, n_sample_eval)
                    score_val = evaluator_validation.calc_stats(y1, y0)

                    if len(train) == 3:
                        # plt.figure()
                        plt.plot(mu_qz_predict[0][ttr == 0.], ztr[ttr == 0.], '.')
                        plt.plot(mu_qz_predict[0][ttr == 1.], ztr[ttr == 1.], '.')
                        plt.savefig(train_dir + str(i) + model_name + '_tr.png')
                        plt.clf()

                        # plt.figure()
                        plt.plot(mu_qz_predict[0][tte == 0.], zte[tte == 0.], '.')
                        plt.plot(mu_qz_predict[0][tte == 1.], zte[tte == 1.], '.')
                        plt.savefig(train_dir + str(i) + model_name + '_te.png')
                        plt.clf()

                    # model_results.append((score_train, score_val, score_test, best_logpvalid))
                    result = (score_train, score_val, score_test, best_logpvalid, zt_y, zy_t,)
                    printf('\n@@@@@@@@ Replication: {}/{}/{}, \n'
                           'tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}, \n'
                           'te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}, \n'
                           "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                           'ELBO: {:0.3f}\n\n'.format(i + 1, zt_y, zy_t,
                                                      score_train[0], score_train[1], score_train[2],
                                                      score_test[0], score_test[1], score_test[2],
                                                      score_val[0], score_val[1], score_val[2],
                                                      best_logpvalid))

                    # release plt memory
                    plt.close('all')

                    import gc
                    gc.collect()

                    return result

                if isinstance(y, list):
                    results = [train_proc(y_, 'model%d' % i) for i, y_ in enumerate(y)]
                    elbos = [result[-3] for result in results]
                    best_elbo_id = elbos.index(max(elbos))
                    result = results[best_elbo_id]

                    printf('#\n#\nbest elbo model', best_elbo_id)

                    pehe_val = [result[1][2] for result in results]
                    best_pehe_id = pehe_val.index(min(pehe_val))

                    printf('#\n#\nbest pehe model', best_pehe_id)
                else:
                    result = train_proc(y)

                sess.close()
                if try_models:
                    # save best model result
                    def is_better(result):
                        """Tell whether `result` beats the current `best_model_result`.

                        `result` is the tuple returned by `train_proc`:
                        (score_train, score_val, score_test, best_logpvalid, zt_y, zy_t).
                        The comparison depends on `validate_models`:

                        * 'pehe': the validation PEHE (`score_val[2]`) must be strictly lower.
                        * 'elbo': the validation ELBO (`best_logpvalid`) must be strictly higher.

                        For any other value of `validate_models` the function returns None,
                        which the caller treats as "not better".
                        """
                        if validate_models == 'pehe':
                            candidate_pehe = result[1][2]
                            incumbent_pehe = best_model_result[1][2]
                            return candidate_pehe < incumbent_pehe

                        if validate_models == 'elbo':
                            candidate_elbo = result[3]
                            incumbent_elbo = best_model_result[3]
                            return candidate_elbo > incumbent_elbo

                        # Unrecognised criterion: keep the falsy None result.
                        return None

                    if best_model_result is None or is_better(result):
                        best_model_result = result

        if try_models:
            score_train, score_val, score_test, logpvalid, zt_y, zy_t = best_model_result
            printf('\n\n@@@@@@@@@@@@@@@@ Best: {}/{}/{}, \n'
                   'tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}, \n'
                   'te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}, \n'
                   "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                   'ELBO: {:0.3f}\n\n\n\n'.format(i + 1, zt_y, zy_t,
                                                  score_train[0], score_train[1], score_train[2],
                                                  score_test[0], score_test[1], score_test[2],
                                                  score_val[0], score_val[1], score_val[2],
                                                  logpvalid))
        # save result of best model
        # scores[i, :] = result[0]  # score_train
        # scores_test[i, :] = result[2]  # score_test
        scores = np.vstack((scores, result[0]))
        scores_test = np.vstack((scores_test, result[2]))

    print('CEVAE model total scores')
    means, stds = np.mean(scores, axis=0), np.std(scores, axis=0)
    printf('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    means, stds = np.mean(scores_test, axis=0), np.std(scores_test, axis=0)
    printf('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    # TODO: backup locals()
    # printf(all_args)
    printf(non_defaults)
    if hasattr(data_gen, 'config'):
        printf(data_gen.config)

    return scores, scores_test


from itertools import product

# Experiment driver: call run() with the configuration below.
#
# A plain `for` loop is used instead of a list comprehension because the
# values returned by run() are not used at this point; building a throwaway
# list only for its side effects is unidiomatic.
#
# Commented-out keyword arguments are alternative settings kept for quick
# toggling.  Commented-out alternative loop headers (parameter sweeps):
#     for nl in [.4, .6, .8, 1]
#     for beta in [1.1 + 0.4*i for i in range(5)]
#     for alpha, scale in product((1e10, 1e9, 1e8, 1e7, 1e6), (1., )
#                               (0.2, 0.5, 1, 2, 5)
#                                 )
for _ in range(1):
    run(
        # try_models=True,
        # validate_models='pehe',
        # proxy_noise_level=nl,
        zt_num=0,
        zy_num=1,
        # z_var_sigma=False,
        metrics_sampled=True,
        y_sampled=True,
        gen_t=False,
        # enc_t=False,
        trans_t=lambda x: (x - .5),
        # cond_x=True,
        # nh_t=1,
        # nh_rep=3,
        # nh_reg=6,
        # zt_y=True,
        # zy_t=True,
        weight_sigma=1, #[0.01*5**i for i in range(5)],  # 'trainable',
        # var_sigma=False,
        validate='elbo',
        # va_rate=0.001,
        n_sample_vld=100,
        n_sample_eval=100,
        # validate_step=1,
        stop_after=10,
        epochs=300,
        # n_repeat_sample=3,
        # index=9,
        # dec_switch=True,
        # enc_switch=True,
        matching='z',
        alpha=0,  # 'trainable',
        # beta=1.2,
        # rand_seed=None,
    )
