#!/usr/bin/env python
"""CEVAE model on IHDP
"""


import os
import edward as ed
import tensorflow as tf
# from sacred import Experiment
from copy import deepcopy
import time
import inspect
from functools import partial

# ex = Experiment('cevae mods')

# from tensorflow.contrib.distributions import Normal
from edward.models import Bernoulli, Normal
from progressbar import ETA, Bar, Percentage, ProgressBar

from data import *
from evaluation import Evaluator
import numpy as np
import time
from scipy.stats import sem

from utils import config_nnbuilder
# config_nnbuilder(initializer=partial(initializers.xavier_initializer, uniform=False))
config_nnbuilder()
from utils import *

from argparse import ArgumentParser

# Command-line options; parsed at import time (module-level side effect).
parser = ArgumentParser()
parser.add_argument('-reps', type=int, default=100)  # number of result rows allocated in `scores`/`scores_test`
# parser.add_argument('-earl', type=int, default=10)
parser.add_argument('-lr', type=float, default=0.001)  # learning rate passed to AdamOptimizer in run()
# NOTE(review): '-opt' is parsed but run() always builds an AdamOptimizer; 'adamax' has no effect.
parser.add_argument('-opt', choices=['adam', 'adamax'], default='adam')
# parser.add_argument('-epochs', type=int, default=100)
parser.add_argument('-print_every', type=int, default=10)  # epoch interval of the periodic evaluation in run()
args = parser.parse_args()

# NOTE(review): not read anywhere in the visible code — presumably consumed by a helper module; verify.
args.true_post = True

# Dataset actually iterated by run(). NOTE(review): replications is hard-coded to 100 and is
# independent of -reps, which only sizes the result arrays below.
artificial = Artificial(replications=100, samples=500, config=yes_backdoor)
# NOTE(review): the IHDP dataset object is constructed but not used by run() in this file.
dataset = IHDP(replications=args.reps)
dimx = 25  # not referenced in the visible code
# One row per replication; the 3 columns follow the printed labels (ITE, ATE, PEHE).
scores = np.zeros((args.reps, 3))
scores_test = np.zeros((args.reps, 3))

M = None  # batch size during training (None leaves the placeholders' batch dimension unspecified)

lamba = 1e-4  # weight decay (not referenced in the visible code)
nh, h = 3, 200  # number and size of hidden layers (only `h` is used by run(); layer counts come from its arguments)

# @ex.config
# def cfg():
#     equiv_confounder = True
#     equiv_conf_covar = True
#     # direct_conf_covar = False
#     # num_z_conf = 10
#     decoder_switch = False
#     decoder_sigma = True
#     # second_switch = False
#     # balanced_sample = True
#     ipm_matching = True


# @ex.automain
def run(train_dir='/Users/abel/Downloads/CEVAE/models/ihdp/',
        switch=True,
        sigma=False,
        matching=False,
        repre=True,
        alpha=100,
        nh_rep=2,
        nh_reg=1,
        n_sample_vld=100,
        n_sample_evl=100,
        validate='pehe',
        validate_step=10,
        stop_after=np.inf,
        epochs=100,
        n_repeat_sample=10,
        index=None,
        ):
    """Train and evaluate the outcome model p(y | x, t) on each data replication.

    For every replication yielded by ``artificial.get_train_valid_test`` a fresh
    TensorFlow graph is built, trained with Edward's KLqp loss (optionally plus
    a Wasserstein imbalance penalty), checkpointed on the best validation PEHE
    and finally scored on the train / test / validation splits.  Progress is
    printed and appended to ``<run folder>/output.txt``.

    Args:
        train_dir: Root folder. A sub-folder named after the current time and
            the non-default arguments is created inside it for logs/checkpoints.
        switch: If True, use separate (mu, sigma) heads for t=0 and t=1;
            otherwise concatenate t to the input of a single head.
        sigma: If True, learn the output scale with a softplus head; else scale 1.
        matching: Add ``alpha * wasserstein(...)`` on the normalized representation.
        repre: Learn a representation of x before the outcome head (also built
            whenever ``matching`` is True).
        alpha: Weight of the imbalance penalty.
        nh_rep: Number of hidden layers of the representation network.
        nh_reg: Number of hidden layers of the outcome network.
        n_sample_vld: Samples passed to ``get_y0_y1`` during periodic evaluation.
        n_sample_evl: Samples passed to ``get_y0_y1`` in the final evaluation.
        validate: ``'pehe'`` checkpoints the model with the best validation PEHE
            and restores it before the final evaluation; any other non-None value
            never checkpoints (the final model is evaluated); ``None`` skips restoring.
        validate_step: Epoch interval of the early-stopping check.
        stop_after: Stop training after this many consecutive early-stopping
            checks without a validation improvement.
        epochs: Maximum number of training epochs.
        n_repeat_sample: Multiplier for the number of minibatches per epoch.
        index: Forwarded to ``get_train_valid_test`` (presumably selects a single
            replication — verify against the data module).

    Returns:
        The module-level ``(scores, scores_test)`` arrays; row i holds the
        ``calc_stats`` result of replication i (ITE, ATE, PEHE per the printed labels).
    """
    # Must stay the first statement: captures exactly the call arguments.
    all_args = deepcopy(locals())
    print(all_args)

    global ARGSPEC
    ARGSPEC = inspect.getfullargspec(run)
    non_defaults = {arg: all_args[arg]
                    for arg, default in zip(ARGSPEC.args, ARGSPEC.defaults)
                    if default != all_args[arg]}

    train_dir += time.strftime("%Y-%m-%d %H:%M:%S ") + str(non_defaults).replace(':', '-') + '/'
    if not os.path.exists(train_dir):
        print("make savefolder: {0:s}...".format(train_dir))
        os.makedirs(train_dir)  # Make folder

    checkpoint_path = train_dir + 'model'
    batch_size = 100  # minibatch size; also used to normalize the reported loss

    def printf(*args, filename=train_dir+'output.txt'):
        """Print to stdout and append the same line to the run's log file."""
        print(*args)
        with open(filename, 'a') as f:
            print(*args, file=f)

    def mu_sigma(input, layers, nout, var_sigma, switch, name):
        """Build (mu, sigma) head(s); two heads (t=0, t=1) when ``switch`` is True."""
        def _mu_sigma(name):
            if var_sigma:
                return fc_net(input, layers, [[nout, None], [nout, tf.nn.softplus]], name)
            else:
                mu = fc_net(input, layers, [[nout, None]], name)
                return mu, tf.ones_like(mu)

        if switch:
            return _mu_sigma(name+'_t0'), _mu_sigma(name+'_t1')
        else:
            return _mu_sigma(name)

    def normal(input, layers, nout, var_sigma, switch, t, name, reweight=1.):
        """Return an Edward ``Normal`` whose parameters depend on ``input`` and ``t``."""
        if switch:
            # Select the t=1 or t=0 head elementwise with the 0/1 treatment indicator.
            ((mu0, sigma0), (mu1, sigma1)) = mu_sigma(input, layers, nout, var_sigma, True, name+'_mu_sigma')
            return Normal(loc=t*mu1+(1.-t)*mu0, scale=reweight*(t*sigma1+(1.-t)*sigma0), name=name)
        else:
            mu, sigma = mu_sigma(tf.concat([input, t], 1), layers, nout, var_sigma, False, name+'_mu_sigma')
            return Normal(loc=mu, scale=reweight*sigma, name=name)

    n_done = 0  # number of replications whose result rows have been filled
    for i, (train, valid, test, contfeats, binfeats) in enumerate(artificial.get_train_valid_test(index, saved='yes_backdoor_xt_zd1_xnl0.npz')):
        if i >= scores.shape[0]:
            # Only args.reps result rows exist; stop instead of failing with an
            # IndexError after training another replication.
            break
        print('\nReplication {}/{}'.format(i + 1, args.reps))
        (xtr, ttr, ytr), (y_cftr, mu0tr, mu1tr) = train
        (xva, tva, yva), (y_cfva, mu0va, mu1va) = valid
        (xte, tte, yte), (y_cfte, mu0te, mu1te) = test
        evaluator_test = Evaluator(yte, tte, y_cf=y_cfte, mu0=mu0te, mu1=mu1te)
        evaluator_validation = Evaluator(yva, tva, y_cf=y_cfva, mu0=mu0va, mu1=mu1va)
        evaluator_train = Evaluator(ytr, ttr, y_cf=y_cftr, mu0=mu0tr, mu1=mu1tr)

        # reorder features with binary first and continuous after
        perm = binfeats + contfeats
        xtr, xva, xte = xtr[:, perm], xva[:, perm], xte[:, perm]
        n_bin = len(binfeats)

        # zero mean, unit variance for y during training (predictions are mapped back)
        ym, ys = np.mean(ytr), np.std(ytr)
        ytr, yva = (ytr - ym) / ys, (yva - ym) / ys

        # NOTE: never updated below, so the final 'ELBO' line always reports -inf.
        best_logpvalid = - np.inf
        best_pehevalid = np.inf

        with tf.Graph().as_default():

            # TODO: seed
            ed.set_seed(1)
            np.random.seed(1)
            tf.set_random_seed(1)

            sess = tf.InteractiveSession()

            x_ph_bin = tf.placeholder(tf.float32, [M, n_bin], name='x_bin')  # binary inputs
            x_ph_cont = tf.placeholder(tf.float32, [M, len(contfeats)], name='x_cont')  # continuous inputs
            t_ph = tf.placeholder(tf.float32, [M, 1])
            y_ph = tf.placeholder(tf.float32, [M, 1])

            x_ph = tf.concat([x_ph_bin, x_ph_cont], 1)

            # p(y|x, t)
            if matching or repre:
                x_repre = fc_net(x_ph, nh_rep * [h], [], 'x_repre')
            else:
                x_repre = x_ph

            y = normal(x_repre, nh_reg * [h], 1, sigma, switch, t_ph, name='py', reweight=.1)

            # Create data dictionary for edward
            data = {y: y_ph}

            # All inference steps should be handled automatically by the inference object.
            inference = ed.KLqp(data=data)

            optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
            inference.initialize(optimizer=optimizer)

            loss_op = inference.loss

            if matching:
                h_rep_norm = x_repre / safe_sqrt(tf.reduce_sum(tf.square(x_repre), axis=1, keep_dims=True))
                imb_distance, _ = wasserstein(h_rep_norm, t_ph, 0.5, lam=10.0, its=10, sq=False,
                                              backpropT=True)

                loss_op += alpha * imb_distance

            train_op = optimizer.minimize(loss_op)

            saver = tf.train.Saver(tf.contrib.slim.get_variables())
            tf.global_variables_initializer().run()

            y_ite = y

            def feed(x, t):
                """Feed dict splitting the reordered features into binary/continuous parts."""
                return {x_ph_bin: x[:, :n_bin], x_ph_cont: x[:, n_bin:], t_ph: t}

            def counterfactual_feeds(x):
                """Feed dicts for the same inputs under t=0 and under t=1."""
                n = x.shape[0]
                return feed(x, np.zeros((n, 1))), feed(x, np.ones((n, 1)))

            def score_split(evaluator, f0, f1, shape, n_samples):
                """Predict y0/y1, undo the y normalization and score with ``evaluator``."""
                y0, y1 = get_y0_y1(sess, y_ite, f0, f1, shape=shape, L=n_samples)
                y0, y1 = y0 * ys + ym, y1 * ys + ym
                return evaluator.calc_stats(y1, y0)

            f0te, f1te = counterfactual_feeds(xte)
            f0va, f1va = counterfactual_feeds(xva)
            f0tr, f1tr = counterfactual_feeds(xtr)

            count_for_stop = 0
            improved_since_check = False  # improvement seen since the last early-stopping check
            have_checkpoint = False       # a checkpoint was written for THIS replication
            n_epoch = epochs
            # At least one iteration so the loss average never divides by zero.
            n_iter_per_epoch = max(1, n_repeat_sample * int(xtr.shape[0] / batch_size))
            idx = np.arange(xtr.shape[0])
            for epoch in range(n_epoch):
                avg_loss = 0.0

                t0 = time.time()
                widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
                pbar = ProgressBar(n_iter_per_epoch, widgets=widgets)
                pbar.start()
                np.random.shuffle(idx)
                for j in range(n_iter_per_epoch):
                    pbar.update(j)

                    # Minibatch drawn with replacement.
                    batch = np.random.choice(idx, batch_size)
                    x_train, y_train, t_train = xtr[batch], ytr[batch], ttr[batch]

                    feed_dict = feed(x_train, t_train)
                    feed_dict[y_ph] = y_train
                    _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict)
                    avg_loss += loss

                # Mean loss per iteration, then per sample.
                avg_loss = avg_loss / n_iter_per_epoch
                avg_loss = avg_loss / batch_size

                # Reset every epoch so a stale flag can never trigger a save later.
                improved = False
                if epoch % args.print_every == 0:
                    score_train = score_split(evaluator_train, f0tr, f1tr, ytr.shape, n_sample_vld)
                    score_test = score_split(evaluator_test, f0te, f1te, yte.shape, n_sample_vld)
                    score_val = score_split(evaluator_validation, f0va, f1va, yva.shape, n_sample_vld)

                    printf("Epoch: {}/{}, log p(x) >= {:0.3f}, \n"
                           "ite_tr: {:0.3f}, ate_tr: {:0.3f}, pehe_tr: {:0.3f}, \n"
                           "ite_te: {:0.3f}, ate_te: {:0.3f}, pehe_te: {:0.3f}, \n"
                           "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                           "dt: {:0.3f}".format(epoch + 1, n_epoch, avg_loss,
                                                score_train[0], score_train[1], score_train[2],
                                                score_test[0], score_test[1], score_test[2],
                                                score_val[0], score_val[1], score_val[2],
                                                time.time() - t0))

                    if score_val[2] < best_pehevalid:
                        printf('Improved validation bound, old: {:0.3f}, new: {:0.3f}'.format(best_pehevalid,
                                                                                              score_val[2]))
                        best_pehevalid = score_val[2]
                        improved = validate == 'pehe'

                # Save in the epoch the improvement was measured, so the checkpoint
                # always holds the model that achieved best_pehevalid.
                if improved:
                    saver.save(sess, checkpoint_path)
                    have_checkpoint = True
                    improved_since_check = True

                if epoch % validate_step == 0 or epoch == (n_epoch - 1):
                    if improved_since_check:
                        count_for_stop = 0
                    else:
                        count_for_stop += 1
                    improved_since_check = False

                    if count_for_stop >= stop_after:
                        break

            if validate is not None:
                if have_checkpoint:
                    saver.restore(sess, checkpoint_path)
                else:
                    # Restoring now would fail or load another replication's weights.
                    printf('No checkpoint was saved; evaluating the final model.')

            score_train = score_split(evaluator_train, f0tr, f1tr, ytr.shape, n_sample_evl)
            score_test = score_split(evaluator_test, f0te, f1te, yte.shape, n_sample_evl)
            score_val = score_split(evaluator_validation, f0va, f1va, yva.shape, n_sample_evl)

            printf('\n@@@@@@@@ Replication: {}, \n'
                   'tr_ite: {:0.3f}, tr_ate: {:0.3f}, tr_pehe: {:0.3f}, \n'
                   'te_ite: {:0.3f}, te_ate: {:0.3f}, te_pehe: {:0.3f}, \n'
                   "ite_va: {:0.3f}, ate_va: {:0.3f}, pehe_va: {:0.3f}, \n"
                   'ELBO: {:0.3f}\n\n'.format(i + 1,
                                              score_train[0], score_train[1], score_train[2],
                                              score_test[0], score_test[1], score_test[2],
                                              score_val[0], score_val[1], score_val[2],
                                              best_logpvalid))

            sess.close()

        # save result of best model
        scores[i, :] = score_train
        scores_test[i, :] = score_test
        n_done = i + 1

    # Summarize only the replications that actually ran (unfilled rows are zeros).
    done_train, done_test = scores[:n_done], scores_test[:n_done]

    printf('CEVAE model total scores')
    means, stds = np.mean(done_train, axis=0), sem(done_train, axis=0)
    printf('train ITE: {:.3f}+-{:.3f}, train ATE: {:.3f}+-{:.3f}, train PEHE: {:.3f}+-{:.3f}'
           ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    means, stds = np.mean(done_test, axis=0), sem(done_test, axis=0)
    printf('test ITE: {:.3f}+-{:.3f}, test ATE: {:.3f}+-{:.3f}, test PEHE: {:.3f}+-{:.3f}'
           ''.format(means[0], stds[0], means[1], stds[1], means[2], stds[2]))

    printf(all_args)
    printf(non_defaults)
    return scores, scores_test


# Only launch an experiment when executed as a script, not when imported.
if __name__ == '__main__':
    run(
        nh_reg=3,
        nh_rep=3,
        # Alternative configurations (uncomment to try):
        # switch=False,
        # sigma=True,
        # matching=True,
        # repre=False,
        # n_repeat_sample=1,
        n_sample_vld=1,
        n_sample_evl=1,
        # index=12,
    )
