# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Tensorflow estimator code for the training of VAEs in a whole variety of flavours

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

# Dependency imports
from absl import flags
import numpy as np
from six.moves import urllib

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python import debug as tf_debug

tfd = tfp.distributions

# Local package imports
from adversarial_attacks.losses import calculate_adversarial_kl
from image_utils.viz_helpers import image_tile_summary, draw_pca_scatter, tf_pca
from stochastic_layers.lvae_layers import make_deterministic_upward_pass, make_stochastic_downward_pass, make_deterministic_downward_pass
from stochastic_layers.generative_layers import make_generative_model, gen_path_from_z, make_p_networks
from stochastic_layers.inference_layers import make_inference_model, make_q_networks
from stochastic_layers.priors import make_gaussian_prior
from stochastic_layers.encoders import make_mlp_encoder, make_conv_encoder
from stochastic_layers.decoders import make_mlp_decoder, make_conv_decoder
from data_processing.utils import load_dataset

from loss_functions.elbo import calculate_decoder_likelihood, calculate_tc_decomp_kl_divs, calculate_elbo, calculate_kl_divs

# ---------------------------------------------------------------------------
# Command-line flags.
#
# `main` turns every flag below into an entry of the `params` dictionary
# (via `FLAGS.flag_values_dict()`), which is then handed to `model_fn`.
# ---------------------------------------------------------------------------

# --- Data and training schedule
flags.DEFINE_string("dataset",
                    default='mnist',
                    help="Dataset, choice 'mnist', 'celeba', 'dsprites'.")
flags.DEFINE_integer("n_epochs",
                     default=200,
                     help="Number of training epochs to run.")
flags.DEFINE_float("threshold",
                   help="threshold in variance for x-dims to have to be kept",
                   default=0.1)
flags.DEFINE_float("learning_rate",
                   default=0.001,
                   help="Initial learning rate.")

# --- Network sizes: `latent_size` and `hidden_size` must have the same
# --- number of entries (one per stochastic layer); `main` enforces this.
flags.DEFINE_multi_integer("latent_size",
                           default=[64, 32, 16, 8, 4],
                           help="Number of dimensions in the latent code (z).")
flags.DEFINE_string("activation",
                    default="leaky_relu",
                    help="Activation function for all hidden layers.")
flags.DEFINE_integer("batch_size", default=1024, help="Batch size.")
flags.DEFINE_multi_integer("hidden_size",
                           default=[1024, 512, 256, 128, 64],
                           help="Hidden layer sizes")
flags.DEFINE_integer("n_samples",
                     default=1,
                     help="Number of samples to use in encoding.")
# NOTE(review): the help text below refers to `mixture_components`, but no
# flag of that name is defined in this file -- confirm where (or whether) it
# is supplied before relying on `--analytic_kl`.
flags.DEFINE_bool(
    "analytic_kl",
    default=False,
    help="Whether or not to use the analytic version of the KL. When set to "
    "False the E_{Z~q(Z|X)}[log p(Z)p(X|Z) - log q(Z|X)] form of the ELBO "
    "will be used. Otherwise the -KL(q(Z|X) || p(Z)) + "
    "E_{Z~q(Z|X)}[log p(X|Z)] form will be used. If analytic_kl is True, "
    "then you must also specify `mixture_components=1`.")

# --- Filesystem locations
flags.DEFINE_string(
    "data_dir",
    default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/data"),
    help="Directory where data is stored (if using real data).")
flags.DEFINE_string("model_dir",
                    default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"),
                                         "vae/"),
                    help="Directory to put the model's fit.")
flags.DEFINE_bool("fake_data",
                  default=False,
                  help="If true, uses fake data instead of MNIST.")
flags.DEFINE_bool("delete_existing",
                  default=False,
                  help="If true, deletes existing `model_dir` directory.")

# --- Debugging: `main` recognises exactly 'tensorboard' and 'cli'.
flags.DEFINE_string(
    "debug",
    default="",
    help="Debugger to attach during training: 'tensorboard' connects to a "
    "TensorBoard debugger on localhost:5000, 'cli' uses the tfdbg CLI; "
    "any other value disables debugging.")

# --- Model family and numerical-stability constants
flags.DEFINE_string("model_type",
                    default="vae",
                    help="type of vae to implement - vae, vae+, lvae or lvae+")
flags.DEFINE_float(
    "lv_eps_z",
    help="small constant added to z-variance to avoid underflow",
    default=1e-5)
flags.DEFINE_float(
    "lv_eps_out",
    help="small constant added to x-variance to avoid underflow",
    default=1e-5)
flags.DEFINE_float(
    "l2_penalty",
    help="L2 regularisation penalty passed to the decoder network",
    default=0.0)
flags.DEFINE_integer(
    "n_datapoints",
    help="Number of datapoints in training datasets, must be =< 50k",
    default=50000)

# --- Inference / generative path topology
flags.DEFINE_string(
    "inf_path",
    help=
    "Type of inference path for vae, either deterministic, chain, skip or inverse",
    default='chain')
flags.DEFINE_string("gen_path",
                    help="Type of generative path for vae, eith chain or skip",
                    default='chain')
flags.DEFINE_bool(
    "x_plus",
    default=False,
    help="If true, concatenates all q samples to be fed into decoder")
flags.DEFINE_string(
    "gen_inf_sharing",
    default=None,
    help=
    "If not None, shares the generative layers with the inference layers. Overrides inference path to be deterministic. Values are 'shared' or 'shared+'"
)

# --- Weights of the decomposed KL terms
flags.DEFINE_float(
    "beta",
    default=1.0,
    help="Beta penalty to apply to total correlation for each stochastic layer"
)
# NOTE(review): `alpha` and `gamma` are integer flags while `beta` is a
# float flag; fractional values are therefore rejected for these two.
flags.DEFINE_integer(
    "alpha",
    default=1,
    help=
    "Alpha penalty to apply to mutual information term for each stochastic layer"
)
flags.DEFINE_integer(
    "gamma",
    default=1,
    help="Gamma penalty to apply to dimension-wise KL for each stochastic layer"
)
flags.DEFINE_float(
    "lambda",
    default=0.0,
    help="Free-bit encoding to apply to KL for each stochastic layer")
flags.DEFINE_string("decomp_mode",
                    default="mws",
                    help="Mode for evaluating decomposition of KL in z")

# --- Likelihood p(x|z) and encoder/decoder architecture
flags.DEFINE_string(
    "x_dist",
    default="Bernoulli",
    help="choice of distribution for p(x|z) - 'Bernoulli', 'Gaussian' or "
    "'Logistic'.")
flags.DEFINE_integer(
    "n_x_mixture",
    default=10,
    help=
    "Number of mixture components in our likelihood function p(x|z), if Logistic"
)
flags.DEFINE_string(
    "enc_dec_type",
    default='mlp',
    help="type of encoder and decoder networks to use, conv or mlp")

FLAGS = flags.FLAGS


def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an estimator.

    The behaviour depends on `mode` and, in PREDICT mode, on which keys are
    present when `features` is a dict (checked in this order):

    * 'ADV'     -- adversarial mode. Reads 'd' (noise mask), 'zt_mu' and
                   'zt_sig' (target embedding), 'target', 'C', 'ADV' and
                   'features'; the noise mask is added to the flattened
                   input and adversarial losses / gradients are returned.
    * 'EMBEDS'  -- decodes the supplied embeddings through
                   `gen_path_from_z` and returns immediately.
    * 'DENOISE' -- encodes `features['DENOISE']` and returns the decoder
                   log-probability of `features['original']`.
    * otherwise -- reads 'labels' and 'features' from the dict.

    In every other case `features` is treated as the raw input and is
    flattened to `(-1, prod(params['IMAGE_SHAPE']))`.

    Arguments:
      features: The input features for the estimator; a tensor, or (in
        PREDICT mode) optionally a dict as described above.
      labels: The labels; forwarded as `factors_batch` to the KL
        decomposition and to the PCA scatter summaries.
      mode: Signifies whether it is train or test or predict.
      params: Some hyperparameters as a dictionary. Keys read here include
        'IMAGE_SHAPE', 'analytic_kl', 'enc_dec_type', 'activation',
        'latent_size', 'hidden_size', 'n_x', 'x_dist', 'n_x_mixture',
        'x_plus', 'lv_eps_z', 'l2_penalty', 'stochastic_depth', 'beta',
        'alpha', 'gamma', 'lambda', 'batch_size', 'n_samples',
        'learning_rate' and 'max_steps'.
      config: The RunConfig, unused here.

    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec instance.

    Raises:
      NotImplementedError: If `analytic_kl` is requested together with
        `mixture_components != 1`.
    """
    del config

    # --- Sentinels for the predict sub-modes. EMBEDS / ADV are replaced by
    # --- the corresponding feature values when that mode is active, so they
    # --- must be compared by identity, never with `!=` (they may be tensors).
    EMBEDS, ADV, DENOISE = False, False, False

    predict_with_dict = (mode == tf.estimator.ModeKeys.PREDICT
                         and isinstance(features, dict))

    # --- If we've activated adversarial mode then we extract data specific to
    # --- adversarial attacks, such as the target image, the target image embedding
    # --- and the regularisation constant C regulating the noise to performance
    # --- tradeoff in the adversarial attack
    if predict_with_dict and 'ADV' in features:
        # --- The noise mask
        d = tf.cast(features['d'], tf.float32)
        # --- mean and variance of the embedding of the target image
        zt_mu, zt_var = features['zt_mu'], features['zt_sig']**2
        # --- target image
        target = features['target']
        # --- noise-performance regularization constant
        C = tf.cast(features['C'], tf.float32)
        # --- The type of adversarial attack, either 'latent' or 'output'
        ADV = features['ADV']
        # --- Format feature data and add noise mask
        features = tf.reshape(features['features'],
                              (-1, np.prod(params['IMAGE_SHAPE'])))
        features = features + d
    # --- If we are using the EMBED mode -- pass in a series of z_l (top layer)
    # --- embedding values to get a reconstruction
    elif predict_with_dict and 'EMBEDS' in features:
        # --- NOTE(review): a 'fixed_layers' entry may accompany 'EMBEDS' but
        # --- nothing in this function consumes it, so it is ignored here.
        EMBEDS = features['EMBEDS']
    # --- If we're in denoising mode
    elif predict_with_dict and 'DENOISE' in features:
        # --- Keep the clean input to score the reconstruction against, and
        # --- encode the 'DENOISE' entry instead
        original_input = features['original']
        features = features['DENOISE']
        DENOISE = True
    # --- Plain prediction with a feature dict
    elif predict_with_dict:
        # --- Extract label data and features from the input dict
        labels = features['labels']
        features = tf.reshape(features['features'],
                              (-1, np.prod(params['IMAGE_SHAPE'])))
    # --- Else if training or evaluating input is simply the feature data
    else:
        features = tf.reshape(features, (-1, np.prod(params['IMAGE_SHAPE'])))

    # --- `mixture_components` is not one of the flags defined in this file,
    # --- so fall back to 1 when it is missing rather than failing with a
    # --- KeyError as soon as `analytic_kl` is switched on.
    if params["analytic_kl"] and params.get("mixture_components", 1) != 1:
        raise NotImplementedError(
            "Using `analytic_kl` is only supported when `mixture_components = 1` "
            "since there's no closed form otherwise.")

    if params['enc_dec_type'] == 'conv':
        # --- Make encoder network from input data to first stochastic layer
        encoder = make_conv_encoder(params["activation"],
                                    params["latent_size"][0],
                                    image_shape=params["IMAGE_SHAPE"],
                                    x_dist=params["x_dist"],
                                    eps=params["lv_eps_z"])
        # --- Make decoder network from first stochastic layer to input space
        decoder = make_conv_decoder(params["activation"],
                                    params["latent_size"],
                                    image_shape=params["IMAGE_SHAPE"],
                                    l2_penalty=params["l2_penalty"],
                                    x_dist=params["x_dist"],
                                    n_x_mixture=params["n_x_mixture"],
                                    x_plus=params["x_plus"])
    else:
        # --- Make encoder network from input data to first stochastic layer
        encoder = make_mlp_encoder(params["activation"],
                                   params["latent_size"][0],
                                   params["hidden_size"][0],
                                   eps=params["lv_eps_z"])

        # --- Make decoder network from first stochastic layer to input space
        decoder = make_mlp_decoder(params["activation"],
                                   params["latent_size"],
                                   params["hidden_size"][0], [params["n_x"]],
                                   l2_penalty=params["l2_penalty"],
                                   x_plus=params["x_plus"])

    # --- Define prior for the final stochastic layer
    latent_prior = make_gaussian_prior(params["latent_size"][-1],
                                       name='p_z_' +
                                       str(params['stochastic_depth'] - 1))

    # --- Make the q / p networks; these are also needed by the EMBED mode
    q_nets, d_nets = make_q_networks(params)
    p_nets = make_p_networks(params)

    # --- EMBED mode only needs the generative path, so return early
    if mode == tf.estimator.ModeKeys.PREDICT and EMBEDS is not False:
        preds = gen_path_from_z(decoder, p_nets, EMBEDS, params)
        return tf.estimator.EstimatorSpec(mode, predictions=preds)

    # --- Make callables to generate networks in inference model
    q_layers, q_z_samples, d_layers = make_inference_model(
        q_nets, d_nets, encoder, features, params)
    # --- Make callables to generate networks in generative model
    p_layers, q_layers, q_z_samples = make_generative_model(
        q_nets, p_nets, q_layers, q_z_samples, d_layers, latent_prior, params)

    # --- Calculate decoder likelihood given samples from the first q layer
    decoder_likelihood, log_p_x_z, mean_log_p_x_z = calculate_decoder_likelihood(
        decoder, features, q_z_samples, params["x_plus"])

    # --- Calculate KL divs across stochastic variables
    kl_z, mean_kl_z, eval_migs, _ = calculate_tc_decomp_kl_divs(
        p_layers,
        q_layers,
        q_z_samples,
        params=params,
        beta=params["beta"],
        alpha=params["alpha"],
        gamma=params["gamma"],
        lamb=params["lambda"],
        factors_batch=labels)

    # --- The undecomposed KL, used for the predict-mode ELBO and to report
    # --- how far the decomposed estimate deviates from it
    og_kl_z, og_mean_kl_z = calculate_kl_divs(p_layers,
                                              q_layers,
                                              q_z_samples,
                                              beta=params["beta"])

    tf.summary.scalar("kl_z/kl_diff_perc",
                      100 * (mean_kl_z - og_mean_kl_z) / og_mean_kl_z)

    # --- If we're predicting we can return results from here
    if mode == tf.estimator.ModeKeys.PREDICT:
        # --- Instantiate our 'preds' dictionary
        # --- * mu and sig are built from q['mu'] and q['var'] of each layer
        # --- --- @NOTE that we are stacking along z dimension and hence these values must be sliced appropriately
        # --- --- NOTE(review): 'sig' is filled from q['var']; confirm whether
        # --- --- consumers expect a std-dev or a variance here
        # --- * recon is the reconstructed image
        # --- * migs are the mig scores for the feature data for each stochastic layer
        # --- * elbo is built from the undecomposed KL and log p(x|z)

        if DENOISE:
            preds = dict(
                orig_likelihood=decoder_likelihood.log_prob(original_input))

        else:
            mu_packaged = tf.concat([
                tf.reshape(
                    q['mu'],
                    (-1, params['batch_size'], params['latent_size'][i]))
                for i, q in enumerate(q_layers)
            ], axis=2)

            sig_packaged = tf.concat([
                tf.reshape(
                    q['var'],
                    (-1, params['batch_size'], params['latent_size'][i]))
                for i, q in enumerate(q_layers)
            ], axis=2)

            # --- Per-dimension log q(z) - log p(z) at the sampled z
            kl_z_dim = tf.concat([
                q["prob"].log_prob(q_z_samples[i]) -
                p["prob"].log_prob(q_z_samples[i])
                for i, (p, q) in enumerate(zip(p_layers, q_layers))
            ], axis=2)

            preds = dict(mu=mu_packaged,
                         sig=sig_packaged,
                         recon=tf.reduce_mean(decoder_likelihood.mean(),
                                              axis=0),
                         kl_z_dim=kl_z_dim,
                         migs=tf.reshape(tf.stack(eval_migs), (-1, 1)),
                         elbo=tf.reshape(-(og_kl_z + log_p_x_z), (-1, 1)))

            # --- L2 norms of the encoder / decoder weights. A variable whose
            # --- name contains both substrings counts as an encoder variable.
            encoder_l2_terms, decoder_l2_terms = [], []
            for var in tf.trainable_variables():
                if 'encoder' in var.name:
                    encoder_l2_terms.append(tf.nn.l2_loss(var))
                elif 'decoder' in var.name:
                    decoder_l2_terms.append(tf.nn.l2_loss(var))
            # --- EstimatorSpec only accepts tensors as predictions, so fall
            # --- back to a zero constant when no variable matched
            preds['encoder_l2'] = (tf.add_n(encoder_l2_terms)
                                   if encoder_l2_terms else tf.constant(0.0))
            preds['decoder_l2'] = (tf.add_n(decoder_l2_terms)
                                   if decoder_l2_terms else tf.constant(0.0))

            # --- If the mode is adversarial
            if ADV is not False:
                # --- Calculate the likelihood of the target image given our adversarial image
                # --- ie p(x_t|z_a) where z_a is the embedding of the adversarial image
                target_likelihood = decoder_likelihood.log_prob(target)
                # --- Negative ELBO of the target under the adversarial embedding
                adv_elbo = -target_likelihood + og_kl_z
                # --- Calculate the adversarial KL: KL(z_a,z_t)
                adv_kl = calculate_adversarial_kl(q_layers, zt_mu, zt_var,
                                                  params)
                # --- If ADV is 'latent' use the adv_kl as loss, otherwise
                # --- (e.g. 'output') use the adv_elbo
                adv_loss = tf.cond(tf.equal(ADV, 'latent'),
                                   lambda: adv_kl,
                                   lambda: adv_elbo)
                # --- Add l2 norm of disturbance to regularize adversarial attack
                adv_loss += C * tf.nn.l2_loss(d)

                preds['target_likelihood'] = target_likelihood
                preds['adv_elbo'] = adv_elbo
                preds['adv_kl'] = adv_kl
                preds['adv_loss'] = adv_loss
                # --- Calculate gradients of adversarial loss w.r.t noise mask
                preds['adv_grad'] = tf.gradients(adv_loss, [d])[0]

        return tf.estimator.EstimatorSpec(mode, predictions=preds)

    # --- Perform variational inference by minimizing the -ELBO.
    # --- Extract current step in training
    global_step = tf.train.get_or_create_global_step()

    elbo, importance_weighted_elbo = calculate_elbo(
        kl_z, log_p_x_z, params["n_samples"], warm_up_beta=tf.constant(1.0))

    # --- The loss is the negative (plain) ELBO; the importance weighted
    # --- ELBO is only reported as an evaluation metric below
    loss = -elbo

    # --- Image summaries
    image_tile_summary("input",
                       tf.cast(features, dtype=tf.float32),
                       rows=1,
                       cols=16,
                       params=params)
    n_rows = min(params["n_samples"], 3)
    n_cols = min(params["batch_size"], 16)
    image_tile_summary("recon/sample",
                       decoder_likelihood.sample()[:n_rows, :n_cols],
                       rows=n_rows,
                       cols=n_cols,
                       params=params)
    image_tile_summary("recon/mean",
                       decoder_likelihood.mean()[:n_rows, :n_cols],
                       rows=n_rows,
                       cols=n_cols,
                       params=params)
    for i, approx_posterior_sample in enumerate(q_z_samples):
        # --- PCA scatter plot of the posterior samples of each layer
        image_tensor = draw_pca_scatter(
            tf_pca(approx_posterior_sample, params["latent_size"][i]), labels,
            params["n_samples"])
        tf.summary.image('pca/layer_' + str(i), image_tensor)

    # --- Decode samples from the prior for visualization.
    if params["x_plus"]:
        # --- The decoder sees every layer's sample concatenated along the
        # --- last axis: one draw from each lower layer plus a reshaped batch
        # --- of draws from the top-layer prior
        samples = [
            tf.squeeze(p['prob'].sample(1), axis=0) for p in p_layers[:-1]
        ]
        samples.append(
            tf.reshape(
                p_layers[-1]['prob'].sample(params["n_samples"] *
                                            params["batch_size"]),
                (params["n_samples"], params["batch_size"],
                 params["latent_size"][-1])))
        samples = tf.concat(samples, axis=2)
    elif len(params['latent_size']) > 1:
        samples = p_layers[0]['prob'].sample(1)
    else:
        samples = tf.reshape(
            p_layers[0]['prob'].sample(params["batch_size"] *
                                       params["n_samples"]),
            (params["n_samples"], params["batch_size"], -1))

    random_image = decoder(samples)
    image_tile_summary("random/sample",
                       tf.cast(random_image.sample(), dtype=tf.float32),
                       rows=4,
                       cols=4,
                       params=params)
    image_tile_summary("random/mean",
                       random_image.mean(),
                       rows=4,
                       cols=4,
                       params=params)

    # --- Cosine-decay the learning rate over the whole training run
    learning_rate = tf.train.cosine_decay(params["learning_rate"], global_step,
                                          params["max_steps"])
    tf.summary.scalar("learning_rate", learning_rate)

    # --- Use Adam as optimizer to get good 'out-of-the-box' performance
    optimizer = tf.train.AdamOptimizer(learning_rate,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-3)
    gradients, variables = zip(*optimizer.compute_gradients(loss))
    # --- Clip the gradients by norm to avoid gradient explosions.
    max_norm = 4
    gradients, _ = tf.clip_by_global_norm(gradients, max_norm)

    train_op = optimizer.apply_gradients(zip(gradients, variables),
                                         global_step=global_step)

    eval_metrics = {
        "elbo": tf.metrics.mean(elbo),
        "elbo/importance_weighted": tf.metrics.mean(importance_weighted_elbo),
        "kl_z": tf.metrics.mean(mean_kl_z),
        "log_p_x_z": tf.metrics.mean(mean_log_p_x_z),
    }

    for i, mig in enumerate(eval_migs):
        eval_metrics["MIG/layer_" + str(i)] = tf.metrics.mean(mig)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metrics)


def main(argv):
    """Parses the flags into `params`, builds the estimator and trains it.

    Steps performed:
      1. Validate the flags and copy them into the `params` dictionary.
      2. Optionally wipe, then (re)create, `FLAGS.model_dir`.
      3. Load the dataset input functions (which may extend `params`).
      4. Derive the number of steps per epoch and the total step budget.
      5. Run `tf.estimator.train_and_evaluate`, checkpointing and writing
         summaries once per epoch.

    Arguments:
      argv: Command-line arguments left over after flag parsing, unused.

    Raises:
      ValueError: If `hidden_size` and `latent_size` differ in length, or
        if `model_type` is not one of the supported values.
    """
    del argv  # unused

    # --- Validate with explicit exceptions: `assert` statements are
    # --- stripped when Python runs with -O
    if len(FLAGS.hidden_size) != len(FLAGS.latent_size):
        raise ValueError(
            "`hidden_size` and `latent_size` must have the same number of "
            "entries, got {} and {}".format(len(FLAGS.hidden_size),
                                            len(FLAGS.latent_size)))

    supported_model_types = ("vae+", "vae", "lvae", "lvae+")
    if FLAGS.model_type not in supported_model_types:
        raise ValueError("`model_type` must be one of {}, got {!r}".format(
            supported_model_types, FLAGS.model_type))

    params = FLAGS.flag_values_dict()
    params['stochastic_depth'] = len(params['latent_size'])

    # --- Swap the activation's name for the tf.nn callable of that name
    params["activation"] = getattr(tf.nn, params["activation"])
    # --- Weakly linear ReLu
    # --- NOTE(review): this only sets an attribute on the function object;
    # --- it does not change the `alpha` argument that tf.nn.leaky_relu uses
    # --- when called. Confirm whether downstream code reads `.alpha`.
    params["activation"].alpha = 0.1

    # --- If old checkpoints exist then delete and start from scratch
    if FLAGS.delete_existing and tf.gfile.Exists(FLAGS.model_dir):
        tf.logging.warn("Deleting old log directory at {}".format(
            FLAGS.model_dir))
        tf.gfile.DeleteRecursively(FLAGS.model_dir)
    tf.gfile.MakeDirs(FLAGS.model_dir)

    train_input_fn, eval_input_fn, params = load_dataset(FLAGS.dataset, params)

    # --- Number of whole batches per epoch; clamped to at least one so a
    # --- batch size larger than the dataset does not yield a zero-step
    # --- schedule (max_steps == 0 is rejected by TrainSpec)
    params["epoch_steps"] = max(
        1, int(float(params["n_datapoints"]) / float(params["batch_size"])))

    params["max_steps"] = int(params["n_epochs"]) * params["epoch_steps"]

    # --- 'tensorboard': attach the TensorBoard debugger plugin, which must
    # --- be listening on localhost:5000
    if FLAGS.debug == "tensorboard":
        hooks = [
            tf_debug.TensorBoardDebugHook(
                grpc_debug_server_addresses="localhost:5000",
                thread_name_filter="MainThread$")
        ]
    # --- 'cli': use tfdbg in the command line
    elif FLAGS.debug == "cli":
        hooks = [tf_debug.LocalCLIDebugHook(thread_name_filter="MainThread$")]
    # --- Anything else: no debugger
    else:
        hooks = []

    # --- Checkpoint and write summaries once per epoch
    estimator = tf.estimator.Estimator(
        model_fn,
        params=params,
        config=tf.estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            save_checkpoints_steps=params["epoch_steps"],
            keep_checkpoint_max=2,
            save_summary_steps=params["epoch_steps"]))

    # --- We force the graph to finalize because of some memory leaks we had.
    tf.get_default_graph().finalize()

    # --- Setup our train and eval specs.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=params["max_steps"],
                                        hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                      steps=None,
                                      throttle_secs=600)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == "__main__":
    # --- Script entry point: let TensorFlow parse the flags, then call
    # --- `main` with the remaining command-line arguments.
    tf.app.run(main=main)
