#!/usr/bin/env python

import os
import edward as ed
import tensorflow as tf
# from sacred import Experiment
from copy import deepcopy
import time
import inspect

# ex = Experiment('cevae mods')

# from tensorflow.contrib.distributions import Normal
from edward.models import *
from progressbar import ETA, Bar, Percentage, ProgressBar
import matplotlib.pyplot as plt

from functools import partial

from datasets import *

from evaluation import Evaluator
import numpy as np
import time
from scipy.stats import sem

from argparse import ArgumentParser

from gcn.layers import GraphConvolution
from gcn.utils import preprocess_features, preprocess_adj

import pickle

# Command-line options. NOTE: parse_args() runs at import time, so importing
# this module from another program will consume that program's sys.argv.
parser = ArgumentParser()
# Only read in the visible code for the "Replication i/N" progress message.
parser.add_argument('-reps', type=int, default=100)
# parser.add_argument('-earl', type=int, default=10)
# NOTE(review): the visible part of run() builds
# tf.train.AdamOptimizer(learning_rate=0.0001) and does not read -lr / -opt;
# confirm whether they are consumed further down the file.
parser.add_argument('-lr', type=float, default=0.001)
parser.add_argument('-opt', choices=['adam', 'adamax'], default='adam')
# parser.add_argument('-epochs', type=int, default=100)
parser.add_argument('-print_every', type=int, default=10)
args = parser.parse_args()

# Extra attribute set programmatically; it is not exposed as a CLI flag.
args.true_post = True

from datasets import data_path
# Global switch: when True, run() builds the p(z|x) prior from graph-convolution
# layers and needs the preprocessed adjacency matrix `adj` loaded below.
is_networked = False
# is_networked = True
# data_gen = Pokec(replications=10, confounder='age_cat')

if is_networked:
    # The `if False:` branch is a disabled one-off job that rebuilds the
    # adjacency matrix and pickles it.
    # NOTE(review): `pokec` is not defined in the visible part of this file, and
    # the branch writes to './adj.pickle' while the else-branch reads from
    # data_path + 'Pokec/adj.pickle' -- confirm before re-enabling.
    if False:
        pokec_network = pokec.network_data()

        adjacency_list = pokec_network.adjacency_list
        neighbours_list = adjacency_list.neighbours
        offsets_list = adjacency_list.offsets

        import collections

        # Neighbours of `vertex` are the slice
        # neighbours_list[offsets_list[vertex]:offsets_list[vertex + 1]].
        adjacency_dict = collections.defaultdict(list)
        for vertex in range(len(offsets_list)-1):
            adjacency_dict[vertex] = neighbours_list[offsets_list[vertex]:offsets_list[vertex+1]]

        # The last vertex has no following offset, so its slice runs to the end.
        adjacency_dict[len(offsets_list)-1] = neighbours_list[offsets_list[-1]:]

        import networkx as nx

        adj = nx.adjacency_matrix(nx.from_dict_of_lists(adjacency_dict))

        adj = preprocess_adj(adj)

        with open('adj.pickle', 'wb') as handle:
            pickle.dump(adj, handle, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        # Load the cached adjacency matrix.
        # SECURITY: pickle.load can execute arbitrary code; only load files you trust.
        with open(data_path + 'Pokec/adj.pickle', 'rb') as handle:
            adj = pickle.load(handle)


# exit(0)
# ATE = []
# for (train, valid, test, contfeats, binfeats) in data_gen.get_train_valid_test():
#     if len(train) == 3:
#         (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr = train
#         (xva, tva, yva), (y_cfva, mu0va, mu1va), zva = valid
#         (xte, tte, yte), (y_cfte, mu0te, mu1te), zte = test
#     else:
#         (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
#         (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
#         (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
#
#     evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
#     evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
#     evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)
#     ATE.append(evaluator_train.true_ate)
#
#
# plt.hist(ATE, 50)
# plt.savefig('hist_.png')
#
# exit(0)

# for i in range(100):
#     artificial.model(i)
# quit()

# Module-level hyper-parameters read by run() as globals.
# NOTE(review): `dimx`, `lamba` and `nh` are not referenced in the visible part
# of the file -- confirm whether they are still used further down.
dimx = 25

# Leading dimension of the input placeholders; None leaves the batch size unspecified.
M = None  # batch size during training

lamba = 1e-4  # weight decay
nh, h = 3, 200  # number and size of hidden layers
# Output width of both GraphConvolution layers (networked mode only).
gcn_dim = 100

# config_nnbuilder() is deliberately called before the star import below;
# keep this statement order.
from utils import config_nnbuilder
# config_nnbuilder(initializer=partial(initializers.xavier_initializer, uniform=False))
config_nnbuilder()
from utils import *

# @ex.config
# def cfg():
#     equiv_confounder = True
#     equiv_conf_covar = True
#     # direct_conf_covar = False
#     # num_z_conf = 10
#     decoder_switch = False
#     decoder_sigma = True
#     # second_switch = False
#     # balanced_sample = True
#     ipm_matching = True


# @ex.automain
def run(train_dir='./',#''/Users/abel/Downloads/CEVAE/models/ihdp/',
        use_gpu=False,
        art_config=None,
        gen_t=False,
        enc_t=True,
        trans_t=lambda x: x,
        cond_x=False,
        normed_z=False,
        suff_z=False,
        zt_num=0,
        zy_num=1,
        zt_y=False,
        zy_t=False,
        z_weight_sigma=1,
        z_var_sigma=True,
        z_cond_t=False,
        z_switch=False,
        metrics_sampled=False,
        y_sampled=False,
        nh_rep=3,
        nh_reg=3,
        nh_enc=3,
        nh_t=3,
        try_models=False,
        validate_models='elbo',
        weight_sigma=1,
        var_sigma=True,
        dec_switch=False,
        enc_switch=False,
        matching=None,
        alpha=100,
        beta=1,
        n_sample_vld=100,
        n_sample_eval=100,
        slim_eval=False,
        validate='elbo',
        va_rate=.0,
        validate_step=10,
        stop_after=np.inf,
        epochs=100,
        n_repeat_sample=10,
        index=None,
        rand_seed=1,
        regression_only=False,
        ):

    # [for k, v in locals().items()]
    all_args = locals()
    print(all_args)

    global ARGSPEC
    ARGSPEC = inspect.getfullargspec(run)
    non_defaults = {arg: all_args[arg] for arg, default
                    in zip(ARGSPEC.args, ARGSPEC.defaults)
                    if default != all_args[arg]
                    and arg != 'trans_t'
                    and not isinstance(all_args[arg], (tf.Tensor, list, ArtConfig))}

    # TODO: backup locals()

    train_dir += time.strftime("%Y-%m-%d %H:%M:%S ") + str(non_defaults).replace(':', '-') + '/'
    if not os.path.exists(train_dir):
        print("make savefolder: {0:s}...".format(train_dir))
        os.makedirs(train_dir)  # Make folder

    def printf(*args, filename=train_dir + 'output.txt'):
        """Echo ``args`` to stdout and append the same line to ``filename``.

        The arguments are joined with single spaces, exactly as ``print`` does.
        The default log path is fixed when the function is defined, i.e. it
        points into the run's save folder.
        """
        line = ' '.join(str(item) for item in args)
        print(line)
        with open(filename, 'a') as log:
            log.write(line + '\n')

    scores = np.empty((0, 3))
    scores_test = np.empty((0, 3))

    data_gen = LinearArtificial(replications=100, samples=500, config=art_config)
    # data_gen = LinearArtificial(replications=5, samples=500, config=yes_backdoor)

    for i, (train, valid, test, contfeats, binfeats) in enumerate(data_gen.get_train_valid_test(index)):
        # continue
        i = index or i
        print('\nReplication {}/{}'.format(i + 1, args.reps))
        if len(train) == 3:
            (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr), ztr = train
            (xva, tva, yva), (y_cfva, mu0va, mu1va), zva = valid
            (xte, tte, yte), (y_cfte, mu0te, mu1te), zte = test
        else:
            (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
            (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
            (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
        if isinstance(data_gen, Pokec):
            def build_evaluator(mask):
                """Return an Evaluator over the rows whose ``data_gen.mask`` equals ``mask``.

                All splits are sliced out of the ``*te`` arrays; the mask value
                selects which rows belong to the requested split.
                """
                rows = data_gen.mask == mask
                factual_y, treatment = yte[rows], tte[rows]
                return Evaluator(factual_y, treatment,
                                 y_cf=y_cfte[rows], mu0=mu0te[rows], mu1=mu1te[rows])

            evaluator_test = build_evaluator(2)
            evaluator_validation = build_evaluator(1)
            evaluator_train = build_evaluator(0)
        else:
            evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
            evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
            evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)
        printf('true ATE tr', evaluator_train.true_ate)
        printf('true ATE va', evaluator_validation.true_ate)
        printf('true ATE te', evaluator_test.true_ate)

        # reorder features with binary first and continuous after
        perm = binfeats + contfeats
        xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]

        # xalltr, talltr, yalltr = np.concatenate([xtr, xva], axis=0), np.concatenate([ttr, tva], axis=0), np.concatenate([ytr, yva], axis=0)
        # evaluator_train = Evaluator(yalltr, talltr, y_cf=np.concatenate([y_cftr, y_cfva], axis=0),
        #                             mu0=np.concatenate([mu0tr, mu0va], axis=0), mu1=np.concatenate([mu1tr, mu1va], axis=0))

        # zero mean, unit variance for y_gathered during training
        ym, ys = np.mean(ytr), np.std(ytr)
        ytr, yva, yte = (ytr - ym) / ys, (yva - ym) / ys, (yte - ym) / ys

        # normalize
        # from sklearn.preprocessing import StandardScaler
        xtr, xva, xte = tuple(StandardScaler().fit_transform(data) for data in (xtr, xva, xte))  # , ytr, yva, yte


        # symmetric t
        # ttr = ttr.astype(np.float32)
        # ttr -= 0.5
        # tva = tva.astype(np.float32)
        # tva -= 0.5
        # tte = tte.astype(np.float32)
        # tte -= 0.5

        best_model_result = None
        for zt_y, zy_t in ((False, False), (False, True), (True, False), (True, True)) if try_models else ((zt_y, zy_t),):

            with tf.Graph().as_default(), tf.device("/gpu:0" if use_gpu else "/cpu:0"):

                # if rand_seed is not None:
                #     ed.set_seed(rand_seed)
                #     np.random.seed(rand_seed)
                #     tf.set_random_seed(rand_seed)

                ed.set_seed(i)
                np.random.seed(i)
                tf.set_random_seed(i)

                sess = tf.InteractiveSession()

                x_ph_bin = tf.placeholder(tf.float32, [M, len(binfeats)], name='x_bin')  # binary inputs
                x_ph_cont = tf.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
                t_ph = tf.placeholder(tf.float32, [M, 1], name='t_ph')
                y_ph = tf.placeholder(tf.float32, [M, 1], name='y_ph')
                if isinstance(data_gen, Pokec):
                    mask_ph = tf.placeholder(tf.int8, [M, ], name='mask_ph')

                x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)
                activation = tf.nn.elu

                # conditional prior p(z|x)
                # can be seen as a representation learner
                if not z_cond_t:
                    if is_networked:

                        # features = preprocess_features(sdfasdf)
                        # support = [preprocess_adj(adj)]

                        placeholders = {
                            'support': [tf.sparse_placeholder(tf.float32)],
                            'features': tf.placeholder(tf.float32, (None, xtr.shape[1])),
                            # 'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
                            # 'labels_mask': tf.placeholder(tf.int32),
                            'dropout': tf.placeholder_with_default(0., shape=()),
                            'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
                        }

                        gcn_layers = [
                            GraphConvolution(input_dim=xtr.shape[1],
                                             output_dim=gcn_dim,  # might need larger
                                             placeholders=placeholders,
                                             act=tf.nn.relu,
                                             dropout=True,
                                             sparse_inputs=False),
                            GraphConvolution(input_dim=gcn_dim,
                                             output_dim=gcn_dim,
                                             placeholders=placeholders,
                                             act=lambda x: x,
                                             dropout=True)
                        ]
                        repre = x_ph
                        for l in gcn_layers:
                            repre = l(repre)

                        mu_pz_x, sigma_pz_x = mu_sigma(repre, [50, 20, 10], zt_num + zy_num, var_sigma=z_var_sigma,
                                                       switch=False, name='pz_x')
                        z = Normal(loc=mu_pz_x, scale=z_weight_sigma*sigma_pz_x, name='pz_networked')
                    else:
                        mu_pz_x, sigma_pz_x = mu_sigma(x_ph, nh_rep * [h], zt_num + zy_num, var_sigma=z_var_sigma, switch=False, name='pz_x')
                        z = Normal(loc=mu_pz_x, scale=z_weight_sigma*sigma_pz_x, name='pz_x')
                else:
                    z = normal(x_ph, nh_rep * [h], zt_num + zy_num, var_sigma=z_var_sigma, switch=z_switch,
                               t=t_ph if z_switch else trans_t(t_ph), reweight=z_weight_sigma, name='pz_xt')

                zt = z[:, :zt_num]
                zy = z[:, zt_num:]
                if suff_z:
                    zy_sq = tf.square(zy)
                    zy = tf.concat([zy, zy_sq], 1)

                # @@ DECODER
                def switch(inp, on):
                    """Return ``inp`` with all its columns when ``on`` is truthy, else with zero columns.

                    The rows are always kept, so the result can be concatenated
                    along axis 1 whether or not the input is switched on.
                    """
                    if not on:
                        return inp[:, :0]
                    return inp[:, :inp.shape[1]]

                # p(t|z)
                # if gen_t:
                #     # TODO: adjust depth
                #     logits = fc_net(tf.concat([zt, switch(zy, zy_t), switch(x_ph, cond_x)], 1),
                #                     nh_t * [h], [[1, None]], 'pt_z')
                #     t = Bernoulli(logits=logits, dtype=tf.float32)
                # else:
                #     # eval indep
                #     logits = fc_net(x_ph, [h], [[1, None]], 'pt_z')


                # p(y_gathered|t,z)
                # mu_py, sigma_py = mu_sigma(tf.concat([t, z if z0_y else zy], 1), nh*[h], 1, var_sigma=var_sigma, switch=False, name='py')
                # y_gathered = Normal(loc=mu_py, scale=py_sigma*sigma_py)
                # def build_y(weight_sigma):
                #     reweight = tf.Variable(.1) if weight_sigma == 'trainable' else weight_sigma
                #
                #     y = normal(tf.concat([zy, switch(zt, zt_y), switch(x_ph, cond_x)], 1),
                #                nh_reg * [h], 1, var_sigma, dec_switch, trans_t(t if gen_t else t_ph),  # name='py',
                #                reweight=reweight)
                #
                #     return y
                #
                # if isinstance(weight_sigma, list):
                #     y = [build_y(w) for w in weight_sigma]
                # else:
                #     y = build_y(weight_sigma)

                reweight = tf.Variable(.1) if weight_sigma == 'trainable' else weight_sigma

                # NOTE: tran_t cannot be used with switch
                y = normal(tf.concat([zy, switch(zt, zt_y), switch(x_ph, cond_x)], 1),
                           nh_reg * [h], 1, var_sigma, dec_switch, trans_t(t if gen_t else t_ph),  # name='py',
                           reweight=reweight, name='y')

                # @@ ENCODER
                # mu_qz, sigma_qz = mu_sigma(tf.concat([t_ph, y_ph, x_ph], 1), nh * [h], z0_num + z1_num, var_sigma=True, switch=False, name='qz')
                # qz = Normal(loc=mu_qz, scale=sigma_qz)

                def normal_no_t(inp, name):
                    """Build a Normal over the zt_num + zy_num latent dims from ``inp`` alone.

                    The network has ``nh_enc`` hidden layers of width ``h`` and is
                    created by ``mu_sigma`` under the given ``name``; no treatment
                    input is passed to it.
                    """
                    hidden_sizes = nh_enc * [h]
                    latent_dim = zt_num + zy_num
                    loc, scale = mu_sigma(inp, hidden_sizes, latent_dim,
                                          var_sigma=True, switch=False, name=name)
                    return Normal(loc=loc, scale=scale)

                # def sel_by(inp, sel, value):
                #     return tf.gather_nd(inp, tf.cast(tf.equal(sel, value), tf.int32))

                inp = tf.concat([y_ph, x_ph], 1)

                if enc_t:
                    qz = normal(inp, nh_enc * [h], zt_num + zy_num, True, enc_switch,
                                t_ph if enc_switch else trans_t(t_ph), name='qz')

                    # qz1 = normal_no_t(sel_by(inp, t_ph, 1.), 'qz1')
                    # qz0 = normal_no_t(sel_by(inp, t_ph, 0.), 'qz0')

                else:
                    qz = normal_no_t(inp, 'qz')

                if regression_only:
                    z = Normal(z.loc, tf.zeros_like(z))
                    qz = z
                    y = Normal(y.loc, 0.001*tf.ones_like(y))

                is_train = tf.placeholder(tf.bool, name="is_train")

                if isinstance(data_gen, Pokec):
                    def mask_equal(values):
                        """Return a boolean tensor that is True where ``mask_ph`` equals any of ``values``.

                        For an empty ``values`` the plain Python ``False`` is returned.
                        """
                        combined = False
                        for value in values:
                            hit = tf.equal(mask_ph, value)
                            combined = tf.logical_or(hit, combined)
                        return combined

                    def gather_normal(rv, pred):
                        """Return a Normal restricted to the rows of ``rv`` where ``pred`` is True.

                        Both ``rv.loc`` and ``rv.scale`` are gathered with the row
                        indices from ``tf.where(pred)``; the new variable is named
                        after ``rv`` with a ``_gathered`` suffix.
                        """
                        gathered_loc = tf.gather(rv.loc, tf.where(pred)[:, -1])
                        gathered_scale = tf.gather(rv.scale, tf.where(pred)[:, -1])
                        return Normal(gathered_loc, gathered_scale, name=rv.name + '_gathered')

                    def gather(tensor, pred):
                        """Return the rows of ``tensor`` at the indices where ``pred`` is True."""
                        rows = tf.where(pred)[:, -1]
                        return tf.gather(tensor, rows)

                    # idx = tf.cond(is_train, lambda: tf.equal(mask_ph, 0), lambda: tf.greater_equal(mask_ph, 0))
                    # qz = Normal(tf.where(idx, qz.loc, z.loc), tf.where(idx, qz.scale, z.scale),
                    #             name='qz_switch')

                def tensors(mask, gathered=False):
                    """Return the tuple ``(z, qz, y, y_ph)``, optionally restricted to masked rows.

                    With ``gathered`` false the full variables are returned
                    unchanged.  Otherwise each one is reduced to the rows whose
                    mask value is in ``mask``; this path relies on the helpers
                    that are only defined for Pokec data.
                    """
                    if not gathered:
                        return z, qz, y, y_ph
                    return (gather_normal(z, mask_equal(mask)),
                            gather_normal(qz, mask_equal(mask)),
                            gather_normal(y, mask_equal(mask)),
                            gather(y_ph, mask_equal(mask)))

                def normed_RV(qz):
                    """Build an affine-normalised version of ``qz`` (currently unused).

                    NOTE(review): the function has no ``return`` statement, so it
                    returns None; ``QzDist`` is built and then dropped.  The only
                    call site in the visible code is commented out, as is the
                    wrapper class that would have returned a RandomVariable.
                    """
                    from tensorflow.python.layers.normalization import BatchNormalization
                    bn = BatchNormalization(center=False, scale=False)
                    # The layer output is discarded; the call is presumably made so the
                    # layer creates and updates its moving statistics -- confirm.
                    bn.apply(qz, training=is_train)

                    # Batch moments while training, the layer's moving statistics otherwise.
                    mean, var = tf.cond(
                        is_train,
                        lambda: tf.nn.moments(qz, axes=[0]),
                        lambda: (bn.moving_mean, bn.moving_variance)
                    )

                    # Affine map x -> (x - mean) / sqrt(var + 1e-3), written as scale * x + offset.
                    scale = tf.rsqrt(var + 1e-3)
                    offset = -mean*scale

                    from tensorflow.contrib.distributions.python.ops.bijectors import Affine
                    # from tensorflow.contrib.linalg.python.ops.linear_operator_diag import LinearOperatorDiag

                    import tensorflow.contrib.distributions as ds

                    # Distribution of the normalised variable; not returned (see docstring).
                    QzDist = ds.TransformedDistribution(
                        distribution=qz,
                        bijector=Affine(shift=offset, scale_identity_multiplier=scale)
                    )
                #
                #     from edward.models import RandomVariable
                #
                #     class QzRV(RandomVariable, ds.TransformedDistribution):
                #         def __init__(self, *args, **kwargs):
                #             super(QzRV, self).__init__(*args, **kwargs)
                #
                #         def _log_prob(self, value):
                #             return QzDist._log_prob(value)
                #
                #         def _sample_n(self, n, seed=None):
                #             return QzDist._sample_n(n, seed)
                #
                #         def _mean(self):
                #             return qz_orig.mean()*scale + offset
                #
                #     # QzRV = type('QzRV', (RandomVariable, QzDist), {})
                #     qz_orig = qz
                #     return QzRV(distribution=QzDist)
                #
                # if normed_z:
                #     qz = normed_RV(qz)

            # def train_proc(y, y_ph, 'model'='model'):

                # Create data dictionary for edward
                z_gathered, qz_gathered, y_gathered, y_ph_gathered = tensors((0,), gathered=is_networked)
                data = {y_gathered: y_ph_gathered}  # t: t_ph is omitted since gen_t is never used
                inference = ed.KLqp({z_gathered: qz_gathered}, data) if not regression_only else ed.KLqp(data=data)

                optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
                # optimizer = tf.train.MomentumOptimizer(learning_rate=0.0000001, momentum=.9, use_nesterov=True)
                inference.initialize(optimizer=optimizer, kl_scaling={z: beta})

                loss_op = inference.loss

                if matching is not None:
                    matched = {
                        # TODO: support zt_num
                        'z': z,
                        # 'mu_z': mu_pz_x,
                    }[matching]
                    h_rep_norm = matched / safe_sqrt(tf.reduce_sum(tf.square(matched), axis=1, keep_dims=True))
                    imb_distance, _ = wasserstein(h_rep_norm, t_ph, 0.5, lam=10.0, its=10, sq=False,
                                                  backpropT=True)
                    p_alpha = (tf.Variable(1.) if alpha == 'trainable' else alpha)
                    loss_op += p_alpha * imb_distance

                if is_networked:
                    for gcn_layer in gcn_layers:
                        for var in gcn_layer.vars.values():
                            loss_op += 1e-5 * tf.nn.l2_loss(var)

                # for batch norm
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = optimizer.minimize(loss_op)

                saver = tf.train.Saver(tf.contrib.slim.get_variables())
                tf.global_variables_initializer().run()

                # save initial model
                saver.save(sess, train_dir + 'model')

                def metrics(x, yf, t, use_mean=True, sample_n=1, slim_eval=False, mask=(0,)):
                    """Evaluate the training objective and its parts on one data set.

                    Args:
                        x: covariates, binary columns first and continuous after.
                        yf: factual outcomes fed to ``y_ph``.
                        t: treatments fed to ``t_ph``.
                        use_mean: plug in ``qz.mean()`` instead of a sample of ``qz``.
                        sample_n: number of ``sess.run`` evaluations averaged per value.
                        slim_eval: only compute the first returned value.
                        mask: mask values selecting rows (used for Pokec data only).

                    Returns:
                        ``(logp_v, KL_v, imbd_v)`` where ``logp_v`` is the averaged
                        ``reconst_err + beta * KL``, ``KL_v`` the averaged ``KL``
                        term and ``imbd_v`` the imbalance distance.  ``KL_v`` and
                        ``imbd_v`` are NaN when ``regression_only`` or ``slim_eval``
                        is set; ``imbd_v`` is also NaN when no matching penalty was
                        built (``matching is None``) or in networked mode.

                    NOTE: every call adds new ops to the graph through ``ed.copy``.
                    NOTE(review): the parameter ``t`` shadows any outer ``t``, so the
                    ``gen_t`` branches would use the fed array as a dict key; the
                    rest of the file says ``gen_t`` is never used -- confirm before
                    enabling it.
                    """
                    fd = {x_ph_bin: x[:, 0:len(binfeats)],
                          x_ph_cont: x[:, len(binfeats):],
                          t_ph: t, y_ph: yf, is_train: False
                          }

                    if is_networked:
                        # feed_dict.update({placeholders['labels']: labels})
                        # feed_dict.update({placeholders['labels_mask']: labels_mask})
                        # x_train is the enclosing scope's most recent training batch.
                        fd.update({placeholders['features']: x_train})
                        fd.update({placeholders['support'][0]: adj})
                    if isinstance(data_gen, Pokec):
                        fd[mask_ph] = data_gen.mask

                    z_gathered, qz_gathered, y_gathered, y_ph_gathered = \
                        tensors(mask, gathered=isinstance(data_gen, Pokec))

                    def build_tensors(use_mean):
                        """Return the (reconstruction, KL) tensors for the chosen posterior plug-in."""
                        qz_or_mean = qz.mean() if use_mean else qz
                        if isinstance(data_gen, Pokec):
                            qz_or_mean_gathered = qz_gathered.mean() if use_mean else qz_gathered
                        else:
                            qz_or_mean_gathered = qz_or_mean

                        # if we want the exact version, we can NOT just remove .mean(), need to sample qz
                        y_post_eval = ed.copy(y_gathered, {z: qz_or_mean, t: t_ph} if gen_t else {z: qz_or_mean},
                                              scope='y_post_eval')
                        log_lik = y_post_eval.log_prob(y_ph_gathered)
                        if gen_t:
                            t_post_eval = ed.copy(t, {z: qz_or_mean}, scope='t_post_eval')
                            log_lik = log_lik + t_post_eval.log_prob(t_ph)

                        reconst_err = tf.reduce_mean(tf.reduce_sum(log_lik, axis=1))

                        # Row-wise sum of log p(z) - log q(z), averaged over rows.
                        KL = tf.reduce_mean(  # should be the same w or wo reduce_mean if use_mean
                            tf.reduce_sum(z_gathered.log_prob(qz_or_mean_gathered) -
                                          qz_gathered.log_prob(qz_or_mean_gathered), axis=1)
                        )

                        return reconst_err, KL

                    reconst_err, KL = build_tensors(use_mean)

                    if regression_only:
                        logp = reconst_err
                        logp_v = np.mean([sess.run(logp, feed_dict=fd) for _ in range(sample_n)])
                        return logp_v, np.nan, np.nan

                    logp = reconst_err + beta * KL
                    logp_v = np.mean([sess.run(logp, feed_dict=fd) for _ in range(sample_n)])
                    if slim_eval:
                        return logp_v, np.nan, np.nan

                    KL_v = np.mean([sess.run(KL, feed_dict=fd) for _ in range(sample_n)])

                    # imb_distance only exists when a matching penalty was added to the
                    # loss; without this guard the default matching=None raised NameError.
                    if matching is not None and not is_networked:
                        imbd_v = sess.run(imb_distance, feed_dict=fd)
                    else:
                        imbd_v = np.nan

                    return logp_v, KL_v, imbd_v

                mu_qz_predict = [None]
                mu_z_predict = [None]
                qz_predict = [None]
                z_predict = [None]

                t_assign = tf.placeholder(tf.float32, [M, 1], name='t_assign')
                # workaround a bug (or maybe I did something wrong)
                #   that there is a mysterious dependence on t_ph when sample y directly
                # instead we first sample z or pz (or their mean) and then plug data into y, this should work anyway
                # NOTE: somehow fixed afterwards
                z_ph = tf.placeholder(tf.float32, [M, 1], name='z_ph')

                def predict_y(x, yf=None, t=None, use_mean=True, sample_n=1, in_sample=True, mask=(0,)):
                    """Predict the outcome under t=0 and t=1 for every row of ``x``.

                    Args:
                        x: covariates, binary columns first and continuous after.
                        yf: factual outcomes, fed to ``y_ph`` for the posterior ``qz``.
                        t: factual treatments; also used for the number of rows of the
                            assigned-treatment arrays, so it must not be None.
                        use_mean: run each outcome head once instead of averaging
                            ``sample_n`` runs.
                        sample_n: number of runs averaged when ``use_mean`` is false.
                        in_sample: use the feeds that include ``yf`` (and ``t`` when
                            ``enc_t``) rather than the covariate-only feeds.
                        mask: mask values selecting rows (used for Pokec data only).

                    Returns:
                        ``(y0, y1)`` mapped back to the original outcome scale with
                        ``ys`` and ``ym``.

                    Side effects: stores the latest ``qz`` / ``z`` means and samples in
                    ``mu_qz_predict``, ``qz_predict``, ``mu_z_predict`` and
                    ``z_predict``.  Every call adds new ops to the graph via ``ed.copy``.
                    """
                    fd = {
                        x_ph_bin: x[:, 0:len(binfeats)], x_ph_cont: x[:, len(binfeats):],
                        is_train: False,
                    }

                    # The graph and mask feeds must be added BEFORE the two copies below.
                    # Previously they were added to `fd` after copying, and `fd` was then
                    # rebound to one of the copies, so they never reached any sess.run.
                    if is_networked:
                        # feed_dict.update({placeholders['labels']: labels})
                        # feed_dict.update({placeholders['labels_mask']: labels_mask})
                        # x_train is the enclosing scope's most recent training batch.
                        fd.update({placeholders['features']: x_train})
                        fd.update({placeholders['support'][0]: adj})
                        # leave dropout to 0 when not training
                    if isinstance(data_gen, Pokec):
                        fd[mask_ph] = data_gen.mask

                    out_sample_fd = dict(fd)
                    if z_cond_t:
                        out_sample_fd[t_ph] = t

                    in_sample_fd = dict(fd)
                    in_sample_fd[y_ph] = yf
                    if enc_t:
                        in_sample_fd[t_ph] = t

                    z_gathered, qz_gathered, y_gathered, y_ph_gathered = \
                        tensors(mask, gathered=isinstance(data_gen, Pokec))

                    if gen_t:
                        # NOTE: only placeholders can be swapped for data through copy
                        # NOTE(review): the parameter `t` shadows any outer `t` here, so
                        # this key is the fed array; gen_t is reported unused -- confirm.
                        muy = ed.copy(y_gathered.loc, {z: qz, t: t_assign}, scope='muy')
                    else:
                        # for each x, sample several times and take the average
                        # NOTE: placeholders can be swapped for data through copy, but
                        # this is buggy when the amounts of data differ
                        replace_dict = {t_ph: t_assign}  # , z: z_ph
                        # Both heads use the same replacement and differ only in scope.
                        mu0p = ed.copy(y_gathered.loc, replace_dict, scope='mu0_post')
                        mu1p = ed.copy(y_gathered.loc, replace_dict, scope='mu1_post')

                    # assigning t directly inside qz does not work, because the recovery
                    # relies on the information carried by the factual t
                    # muy = ed.copy(y_gathered.loc, {z: qz.mean()}, scope='muy_post')

                    # if not enc_t:
                    #     del fd[t_ph]
                    mu_qz_predict[0] = sess.run(qz.mean(), feed_dict=in_sample_fd)
                    qz_predict[0] = sess.run(qz, feed_dict=in_sample_fd)

                    mu_z_predict[0] = sess.run(z.mean(), feed_dict=out_sample_fd)
                    z_predict[0] = sess.run(z, feed_dict=out_sample_fd)

                    fd = in_sample_fd if in_sample else out_sample_fd

                    # fd[t_ph] = np.ones((xte.shape[0], 1))
                    # y1 = sess.run(ed.copy(y_gathered.loc, {z: tf.constant(mu_qz_predict[0], dtype=tf.float32)}), feed_dict=fd)
                    #
                    # fd[t_ph] = np.zeros((xte.shape[0], 1))
                    # y0 = sess.run(ed.copy(y_gathered.loc, {z: tf.constant(mu_qz_predict[0], dtype=tf.float32)}), feed_dict=fd)

                    if gen_t:
                        fd[t_assign] = np.zeros((t.shape[0], 1))
                        y0 = sess.run(muy, feed_dict=fd)
                        fd[t_assign] = np.ones((t.shape[0], 1))
                        y1 = sess.run(muy, feed_dict=fd)
                    else:
                        if use_mean:
                            # NOTE(review): z_ph is declared with one column and is not
                            # wired into mu0p/mu1p (the `z: z_ph` replacement above is
                            # commented out), so this feed has no effect on the result
                            # and presumably fails when zt_num + zy_num != 1 -- confirm.
                            fd[z_ph] = mu_qz_predict[0] if in_sample else mu_z_predict[0]

                            fd[t_assign] = np.zeros((t.shape[0], 1))
                            y0 = sess.run(mu0p, feed_dict=fd)
                            fd[t_assign] = np.ones((t.shape[0], 1))
                            y1 = sess.run(mu1p, feed_dict=fd)
                        else:
                            fd[t_assign] = np.zeros((t.shape[0], 1))
                            y0 = np.mean(np.hstack([sess.run(mu0p, feed_dict=fd) for _ in range(sample_n)]), axis=1, keepdims=True)
                            fd[t_assign] = np.ones((t.shape[0], 1))
                            y1 = np.mean(np.hstack([sess.run(mu1p, feed_dict=fd) for _ in range(sample_n)]), axis=1, keepdims=True)

                        # y0_sample, y1_sample = [], []
                        # for _ in range(sample_n):
                        #     # NOTE: it is meaningless to set sample_n > 1 with use_mean
                        #     if use_mean:
                        #         fd[z_ph] = mu_qz_predict[0] if in_sample else mu_z_predict[0]
                        #     else:
                        #         fd[z_ph] = sess.run(qz, feed_dict=fd) if in_sample else sess.run(z, feed_dict=fd)
                        #
                        #     fd[t_assign] = np.zeros((x.shape[0], 1))
                        #     y0_sample.append(sess.run(mu0p, feed_dict=fd))
                        #     fd[t_assign] = np.ones((x.shape[0], 1))
                        #     y1_sample.append(sess.run(mu1p, feed_dict=fd))
                        #
                        # y0 = np.mean(np.hstack(y0_sample), axis=1, keepdims=True)
                        # y1 = np.mean(np.hstack(y1_sample), axis=1, keepdims=True)

                    # Undo the outcome standardisation applied before training.
                    y0, y1 = y0 * ys + ym, y1 * ys + ym

                    return y0, y1

                # t0te, t1te = np.zeros((xte.shape[0], 1)), np.ones((xte.shape[0], 1))
                # f1te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t1te}
                # f0te = {x_ph_bin: xte[:, 0:len(binfeats)], x_ph_cont: xte[:, len(binfeats):], t_ph: t0te}
                # t0va, t1va = np.zeros((xva.shape[0], 1)), np.ones((xva.shape[0], 1))
                # f1va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t1va}
                # f0va = {x_ph_bin: xva[:, 0:len(binfeats)], x_ph_cont: xva[:, len(binfeats):], t_ph: t0va}
                # t0tr, t1tr = np.zeros((xtr.shape[0], 1)), np.ones((xtr.shape[0], 1))
                # f1tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t1tr}
                # f0tr = {x_ph_bin: xtr[:, 0:len(binfeats)], x_ph_cont: xtr[:, len(binfeats):], t_ph: t0tr}


                # z_sample = [sess.run(z, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                # z_sample = np.hstack(z_sample)
                # t_sample = [sess.run(t, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                # t_sample = np.hstack(t_sample)
                #
                # import dcor
                # dcor_zt = np.mean([dcor.distance_correlation(z, t) for z, t in zip(z_sample, t_sample)])
                # print('init dcor', dcor_zt)

                count_for_stop = 0
                n_epoch, n_iter_per_epoch, idx = epochs, n_repeat_sample*int(xtr.shape[0] / 100), np.arange(xtr.shape[0])
                if is_networked:
                    n_epoch, n_iter_per_epoch = 100, 1

                best_logpvalid = - np.inf
                logpvalid = - np.inf
                best_pehevalid = np.inf
                best_atevalid = np.inf

                plt.figure()

                for epoch in range(n_epoch):
                    avg_loss = 0.0

                    t0 = time.time()
                    widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
                    pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
                    pbar.start()
                    np.random.shuffle(idx)
                    for j in range(n_iter_per_epoch):
                        pbar.update(j)

                        # networked data uses a GCN, which trains on the full data (it cannot be trained by mini-batch)
                        batch = np.random.choice(idx, 100) if not is_networked else idx
                        x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch]

                        # if ipm_matching:

                        feed_dict = {x_ph_bin: x_train[:, 0:len(binfeats)],
                                                                x_ph_cont: x_train[:, len(binfeats):],
                                                                t_ph: t_train, y_ph: y_train, is_train: True}

                        if is_networked:
                            # feed_dict.update({placeholders['labels']: labels})
                            # feed_dict.update({placeholders['labels_mask']: labels_mask})
                            feed_dict.update({placeholders['features']: x_train})
                            feed_dict.update({placeholders['support'][0]: adj})
                            # the GCN uses a separate dropout rate
                            feed_dict.update({placeholders['dropout']: 0.1})
                            feed_dict[mask_ph] = data_gen.mask

                        # logpvalid, KLvalid, imbdvalid = metrics(xte, yte, tte, not metrics_sampled, n_sample_vld,
                        #                                         mask=(2,))

                        _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)
                        avg_loss += loss
                        # else:
                        #     info_dict = inference.update(feed_dict={x_ph_bin: x_train[:, 0:len(binfeats)],
                        #                                             x_ph_cont: x_train[:, len(binfeats):],
                        #                                             t_ph: t_train, y_ph: y_train})
                        #     avg_loss += info_dict['loss']

                    avg_loss = avg_loss / n_iter_per_epoch
                    avg_loss = avg_loss / (len(idx) if is_networked else 100)

                    if np.isnan(avg_loss):
                        printf('training diverged!\n')
                        break

                    if epoch % args.print_every == 0:
                        # TODO: on IHDP, Within-sample error is computed over both the training and validation sets

                        if slim_eval:
                            if is_networked:
                                # during training, only evaluate validation
                                pass
                            else:
                                # during training, evaluate validation
                                logpvalid, KLvalid, imbdvalid = metrics(xva, yva, tva, not metrics_sampled,
                                                                        n_sample_vld, mask=(1,))
                                # OPT in-sample estimation
                                pass
                        else:
                            logptr, KLtr, imbdtr = metrics(xtr, ytr, ttr, not metrics_sampled, n_sample_vld, mask=(0,))
                            y0, y1 = predict_y(xtr, ytr, ttr, not y_sampled, n_sample_vld, mask=(0,))
                            # plt.figure()
                            # plt.plot(mu_qz_predict[0][ttr == 0.], ztr[ttr == 0.], '.')
                            # plt.plot(mu_qz_predict[0][ttr == 1.], ztr[ttr == 1.], '.')
                            # plt.savefig(train_dir + 'zz%d_e%d_tr.png' % (i, epoch))
                            # plt.clf()

                            score_train = evaluator_train.calc_stats(y1, y0)
                            # rmses_train = evaluator_train.y_errors(y0, y1)

                            logpvalid, KLvalid, imbdvalid = metrics(xva, yva, tva, not metrics_sampled, n_sample_vld, mask=(1,))
                            # imbdvalid = sess.run(imb_distance, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                            #                                               x_ph_cont: xva[:, len(binfeats):],
                            #                                               t_ph: tva, y_ph: yva, is_train: False})

                            y0, y1 = predict_y(xva, yva, tva, not y_sampled, n_sample_vld, mask=(1,))
                            score_val = evaluator_validation.calc_stats(y1, y0)

                            y0, y1 = predict_y(xte, yte, tte, not y_sampled, n_sample_vld, in_sample=False, mask=(2,))
                            score_test = evaluator_test.calc_stats(y1, y0)

                        if (len(train) == 3 or isinstance(data_gen, Pokec)) and not slim_eval:
                            def plots(zte, tte, use_mean=True):
                                """Save six diagnostic scatter plots for the current replication/epoch.

                                Each figure draws two dotted series, one selected by ``tte == 0.``
                                and one by ``tte == 1.``, saves it under ``train_dir`` with a name
                                built from ``i`` and ``epoch``, then clears the current figure.

                                zte: array plotted on the y-axis of the ``qtz`` figure.
                                tte: array compared against 0. and 1. to build the two selections.
                                use_mean: when true the ``yzf``/``yzcf``/``qtz``/``pqz`` figures use
                                    ``mu_qz_predict[0]`` and ``mu_z_predict[0]``; otherwise they
                                    use ``qz_predict[0]`` and ``z_predict[0]``.
                                """
                                q_vals = mu_qz_predict[0] if use_mean else qz_predict[0]
                                p_vals = mu_z_predict[0] if use_mean else z_predict[0]
                                # the two boolean selections are reused by every figure
                                is_t0, is_t1 = tte == 0., tte == 1.

                                def draw(first, second, name_fmt):
                                    # plot both (x, y) series as dots, save the figure, reset it
                                    for xs, ys in (first, second):
                                        plt.plot(xs, ys, '.')
                                    plt.savefig(train_dir + name_fmt % (i, epoch) + 'model' + '_te.png')
                                    plt.clf()

                                # y0 on the tte == 0 rows and y1 on the tte == 1 rows
                                draw((y0[is_t0], q_vals[is_t0]),
                                     (y1[is_t1], q_vals[is_t1]),
                                     '%d_e%d_yzf')
                                # the swapped pairing: y0 on tte == 1 rows, y1 on tte == 0 rows
                                draw((y0[is_t1], q_vals[is_t1]),
                                     (y1[is_t0], q_vals[is_t0]),
                                     '%d_e%d_yzcf')

                                draw((q_vals[is_t0], zte[is_t0]),
                                     (q_vals[is_t1], zte[is_t1]),
                                     '%d_e%d_qtz')

                                # always mean vs. sample here, regardless of use_mean
                                draw((mu_qz_predict[0][is_t0], qz_predict[0][is_t0]),
                                     (mu_qz_predict[0][is_t1], qz_predict[0][is_t1]),
                                     '%d_e%d_qz')

                                # plt.figure()
                                # plt.plot(mu_z_predict[0][tte == 0.], zte[tte == 0.], '.')
                                # plt.plot(mu_z_predict[0][tte == 1.], zte[tte == 1.], '.')
                                # plt.savefig(train_dir + 'pz_tz_%d_e%d' % (i, epoch) + 'model' + '_te.png')
                                # plt.clf()

                                draw((p_vals[is_t0], q_vals[is_t0]),
                                     (p_vals[is_t1], q_vals[is_t1]),
                                     '%d_e%d_pqz')

                                # always mean vs. sample here, regardless of use_mean
                                draw((mu_z_predict[0][is_t0], z_predict[0][is_t0]),
                                     (mu_z_predict[0][is_t1], z_predict[0][is_t1]),
                                     '%d_e%d_z')

                                # plt.figure()
                                # plt.plot(mu_z_predict[0][tte == 0.], xte[tte == 0.], '.')
                                # plt.plot(mu_z_predict[0][tte == 1.], xte[tte == 1.], '.')
                                # plt.savefig(train_dir + 'pz_tz_%d_e%d' % (i, epoch) + 'model' + '_te.png')
                                # plt.clf()

                            if isinstance(data_gen, Pokec):
                                plots(data_gen.confounder_data(2), tte[data_gen.mask==2])
                            else:
                                plots(zte, tte)


                        # TODO: high dim z
                        # z_sample = [sess.run(z, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                        #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                        # z_sample = np.hstack(z_sample)
                        # t_sample = [sess.run(t, feed_dict={x_ph_bin: xva[:, 0:len(binfeats)],
                        #                                    x_ph_cont: xva[:, len(binfeats):]}) for _ in range(100)]
                        # t_sample = np.hstack(t_sample)
                        #
                        # import dcor
                        # dcor_zt = np.mean([dcor.distance_correlation(z, t) for z, t in zip(z_sample, t_sample)])
                        # print('dcor', dcor_zt)

                        if not slim_eval:
                            printf("Epoch: {}/{}, loss >= {:0.3f}, \n"
                                  "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                                  # "rmse_f_tr: {:0.3f}, rmse_cf_tr: {:0.3f}, \n"
                                  "ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, \n"
                                  "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                                  "ELBO: tr {:0.3f}, va {:0.3f} \n"
                                   "KL: tr {:0.3f}, va {:0.3f} \n"
                                   "imbd: tr {:0.3f}, va {:0.3f} \n"
                                  "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                       score_train[0], score_train[1], score_train[2],
                                                       # rmses_train[0], rmses_train[1],
                                                       score_test[0], score_test[1], score_test[2],
                                                       score_val[0], score_val[1], score_val[2],
                                                       logptr, logpvalid,
                                                       KLtr, KLvalid,
                                                       imbdtr, imbdvalid,
                                                   time.time() - t0))

                        if weight_sigma == 'trainable':
                            printf('weight_sigma: %f' % sess.run(reweight))
                        # printf(sess.run(reweight))
                        if alpha == 'trainable':
                            printf('alpha: %f' % sess.run(p_alpha))

                        save_model = False
                        if best_logpvalid == -np.inf or \
                                (logpvalid > best_logpvalid and (logpvalid-best_logpvalid)/np.abs(best_logpvalid) > va_rate):
                            printf('Improved validation elbo, old: {:0.3f}, new: {:0.3f}'.format(best_logpvalid,
                                                                                                  logpvalid))
                            best_logpvalid = logpvalid
                            if validate == 'elbo':
                                save_model = True

                        if validate == 'pehe' and not is_networked and score_val[2]/best_pehevalid < 1 - va_rate:
                            printf('Improved validation pehe, old: {:0.3f}, new: {:0.3f}'.format(best_pehevalid,
                                                                                                  score_val[2]))
                            best_pehevalid = score_val[2]
                            save_model = True

                    if epoch % validate_step == 0 or epoch == (n_epoch - 1):
                        if save_model:
                            saver.save(sess, train_dir + 'model')
                            count_for_stop = 0
                        else:
                            count_for_stop += 1

                        if count_for_stop == stop_after:
                            break

                if validate is not None:
                    saver.restore(sess, train_dir + 'model')

                y0, y1 = predict_y(xte, yte, tte, not y_sampled, n_sample_eval, in_sample=False, mask=(2,))
                score_test = evaluator_test.calc_stats(y1, y0)

                y0, y1 = predict_y(xtr, ytr, ttr, not y_sampled, n_sample_eval, mask=(0,))
                score_train = evaluator_train.calc_stats(y1, y0)

                y0, y1 = predict_y(xva, yva, tva, not y_sampled, n_sample_eval, mask=(1,))
                score_val = evaluator_validation.calc_stats(y1, y0)

                if len(train) == 3:
                    # plt.figure()
                    plt.plot(mu_qz_predict[0][ttr == 0.], ztr[ttr == 0.], '.')
                    plt.plot(mu_qz_predict[0][ttr == 1.], ztr[ttr == 1.], '.')
                    plt.savefig(train_dir + str(i) + 'model' + '_tr.png')
                    plt.clf()

                    # plt.figure()
                    plt.plot(mu_qz_predict[0][tte == 0.], zte[tte == 0.], '.')
                    plt.plot(mu_qz_predict[0][tte == 1.], zte[tte == 1.], '.')
                    plt.savefig(train_dir + str(i) + 'model' + '_te.png')
                    plt.clf()

                # model_results.append((score_train, score_val, score_test, best_logpvalid))
                result = (score_train, score_val, score_test, best_logpvalid, zt_y, zy_t,)
                printf('\n@@@@@@@@ Replication: {}/{}/{}, \n'
                       'tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}, \n'
                       'te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}, \n'
                       "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                       'ELBO: {:0.3f}\n\n'.format(i + 1, zt_y, zy_t,
                                                  score_train[0], score_train[1], score_train[2],
                                                  score_test[0], score_test[1], score_test[2],
                                                  score_val[0], score_val[1], score_val[2],
                                                  best_logpvalid))

                inference.finalize()

                # release plt memory
                plt.close('all')

                import gc
                gc.collect()

                # return result

                # if isinstance(y, list):
                #     results = [train_proc(y_, 'model%d' % i) for i, y_ in enumerate(y)]
                #     elbos = [result[-3] for result in results]
                #     best_elbo_id = elbos.index(max(elbos))
                #     result = results[best_elbo_id]
                #
                #     printf('#\n#\nbest elbo model', best_elbo_id)
                #
                #     pehe_val = [result[1][2] for result in results]
                #     best_pehe_id = pehe_val.index(min(pehe_val))
                #
                #     printf('#\n#\nbest pehe model', best_pehe_id)
                # else:
                #     result = train_proc(y)

                sess.close()
                if try_models:
                    # save best model result
                    def is_better(result):
                        """Compare ``result`` against ``best_model_result`` under ``validate_models``.

                        'elbo': true when ``result[3]`` is strictly larger.
                        'pehe': true when ``result[1][2]`` is strictly smaller.
                        Any other setting yields None, which the caller treats as false.
                        """
                        if validate_models == 'elbo':
                            return result[3] > best_model_result[3]
                        if validate_models == 'pehe':
                            return result[1][2] < best_model_result[1][2]
                        return None

                    if best_model_result is None or is_better(result):
                        best_model_result = result

        if try_models:
            score_train, score_val, score_test, logpvalid, zt_y, zy_t = best_model_result
            printf('\n\n@@@@@@@@@@@@@@@@ Best: {}/{}/{}, \n'
                   'tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}, \n'
                   'te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}, \n'
                   "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                   'ELBO: {:0.3f}\n\n\n\n'.format(i + 1, zt_y, zy_t,
                                                  score_train[0], score_train[1], score_train[2],
                                                  score_test[0], score_test[1], score_test[2],
                                                  score_val[0], score_val[1], score_val[2],
                                                  logpvalid))
        # save result of best model
        # scores[i, :] = result[0]  # score_train
        # scores_test[i, :] = result[2]  # score_test
        scores = np.vstack((scores, result[0]))
        scores_test = np.vstack((scores_test, result[2]))

    print('CEVAE model total scores')
    means, stds = np.mean(scores, axis=0), np.std(scores, axis=0)
    printf('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    means, stds = np.mean(scores_test, axis=0), np.std(scores_test, axis=0)
    printf('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}' \
          ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    # TODO: backup locals()
    # printf(all_args)
    printf(non_defaults)
    if hasattr(data_gen, 'config'):
        printf(data_gen.config)

    return scores, scores_test


from itertools import product

# art_config = yes_backdoor

# Experiment sweep.  For every proxy noise level, `run` is called once with the
# `y_backdoor` config and once with each of two `yes_backdoor` variants
# (proxy_t_dim/proxy_no_dim set to 0/3 and 3/0).  `run` is called for its side
# effects and its return value is not kept, so this is a plain loop rather
# than a list comprehension.
for _noise_level in (.2, .4, .6, .8):
    _sweep_configs = (y_backdoor._replace(proxy_noise_level=_noise_level), ) + tuple(
        yes_backdoor._replace(proxy_noise_level=_noise_level, proxy_t_dim=n, proxy_no_dim=3-n)
        for n in (0, 3))
    # alternative single-run driver: for _ in range(1)
    for _art_config in _sweep_configs:
        run(
            # try_models=True,
            # validate_models='pehe',
            art_config=_art_config,
            # zt_num=0,
            zy_num=1,
            # z_weight_sigma=0.1,
            # z_var_sigma=False,
            # z_cond_t=True,
            # z_switch=True,
            # metrics_sampled=True,
            # y_sampled=True,
            # gen_t=False,
            # enc_t=False,
            trans_t=lambda x: (x - .5),
            # cond_x=True,
            # nh_t=1,
            # nh_rep=3,
            # nh_reg=6,
            # zt_y=True,
            # zy_t=True,
            # weight_sigma=0.5, #[0.01*5**i for i in range(5)],  # 'trainable',
            # var_sigma=False,
            # validate='elbo',
            va_rate=0.001,
            n_sample_vld=1,
            n_sample_eval=1,
            # slim_eval=True,
            # validate_step=1,
            stop_after=10,
            epochs=1000,
            # n_repeat_sample=1,
            # index=1,
            # dec_switch=True,
            # enc_switch=True,
            matching='z',
            alpha=0,  # 'trainable',
            # beta=1.5,
            # rand_seed=None,
            # regression_only=True,
            slim_eval=True,
        )

# for beta in [1.1 + 0.4*i for i in range(5)]]
    # for alpha, scale in product((1e10, 1e9, 1e8, 1e7, 1e6), (1., )
    #                           (0.2, 0.5, 1, 2, 5)
    #                             )
# ]
