import tensorflow.compat.v1 as tf
import os
import json


def get_activation_fn(type='relu'):
  """
    Return a tensorflow activation function given its string name.

    Args:
        type: one of 'relu', 'elu', 'tanh', 'sigmoid', 'softplus' (the `tf.nn`
          function of that name is returned), None (returns None, i.e. no
          activation), or an activation callable, which is returned
          unchanged. The parameter name shadows the `type` builtin; it is kept
          for keyword-argument compatibility.

    Returns:
        The activation callable, or None.

    Raises:
        Exception: if `type` is any other value.
    """
  if type is None:
    return None
  # Allow callers (e.g. `fully_connected_layer`) to pass a tf.nn function
  # directly instead of its name.
  if callable(type):
    return type
  if type in ('relu', 'elu', 'tanh', 'sigmoid', 'softplus'):
    # Each supported name is exactly the attribute name in tf.nn.
    return getattr(tf.nn, type)
  raise Exception('Activation function is not supported.')


def linear(input,
           output_size,
           activation_fn=None,
           batch_norm=False,
           is_training=True):
  """
    Dense (fully connected) layer with optional batch norm and activation.

    Args:
        input: input tensor, passed to `tf.layers.dense`.
        output_size: number of output units.
        activation_fn: activation callable (e.g. tf.nn.relu), its string name
          (resolved with `get_activation_fn`), or None for a purely linear
          output.
        batch_norm (bool): apply `tf.layers.batch_normalization` over axis 1
          after the dense projection. It only takes effect when
          `activation_fn` is not None, so layers without an activation are
          never normalized.
        is_training (bool): forwarded to batch normalization as `training`.

    Returns:
        The layer's output tensor.
    """
  out = tf.layers.dense(input, output_size)

  normalize = (activation_fn is not None) and (batch_norm == True)
  if normalize:
    out = tf.layers.batch_normalization(out, axis=1, training=is_training)

  # String names are resolved only after the dense/batch-norm ops exist.
  if isinstance(activation_fn, str):
    activation_fn = get_activation_fn(activation_fn)

  return out if activation_fn is None else activation_fn(out)


def fully_connected_layer(input, is_training=True, **kwargs):
  """
    Stacks `num_layers` identical dense layers built with `linear`.

    Args:
        input: input tensor.
        is_training (bool): forwarded to `linear` (batch normalization mode).
        **kwargs: optional settings:
          `size` (default 256): units per layer.
          `activation_fn` (default 'relu'): passed to `get_activation_fn`.
          `num_layers` (default 1): number of stacked layers.
          `use_batch_norm` (default False): batch-normalize each layer.

    Returns:
        Output of the last layer (`input` itself when `num_layers` <= 0).
    """
  layer_args = dict(
      activation_fn=get_activation_fn(kwargs.get('activation_fn', 'relu')),
      batch_norm=kwargs.get('use_batch_norm', False),
      is_training=is_training)
  depth = kwargs.get('num_layers', 1)
  width = kwargs.get('size', 256)

  net = input
  for _ in range(depth):
    net = linear(net, width, **layer_args)
  return net


def get_reduce_loss_func(type='sum_mean', seq_len=None):
  """

    Args:
        loss: expects [batch_size, loss_size] or [batch_size, sequence_length,
          loss_size].
        type: "sum_mean", "mean", "sum".

    Returns:
    """

  def reduce_sum_mean(loss):
    """
        Average batch loss. First calculates per sample loss by summing over the
        second and third dimensions and then
        takes the average.
        """
    rank = len(loss.get_shape())
    if rank > 3 or rank < 2:
      raise Exception('Loss rank must be 2 or 3.')

    if rank == 3:
      return tf.reduce_mean(tf.reduce_sum(loss, axis=[1, 2]))
    elif rank == 2:
      return tf.reduce_mean(tf.reduce_sum(loss, axis=[1]))

  def reduce_mean_per_step(loss):
    """
        First calculates average loss per sample (loss per step), and then takes
        average over samples. Loss per step
        requires sequence length. If all samples have the same sequence length
        then this is equivalent to `mean`.
        """
    rank = len(loss.get_shape())
    if rank > 3 or rank < 2:
      raise Exception('Loss rank must be 2 or 3.')

    # Calculate loss per step.
    if rank == 3:
      step_loss_per_sample = tf.reduce_sum(
          loss, axis=[1, 2]) / tf.cast(seq_len, tf.float32)
    elif rank == 2:
      step_loss_per_sample = tf.reduce_sum(
          loss, axis=[1]) / tf.cast(seq_len, tf.float32)
    # Calculate average (per step) sample loss.
    return tf.reduce_mean(step_loss_per_sample)

  if type == 'sum_mean':
    return reduce_sum_mean
  elif type == 'sum':
    return tf.reduce_sum
  elif type == 'mean':
    return tf.reduce_mean
  elif type == 'mean_per_step':
    return reduce_mean_per_step


def get_rnn_cell(**kwargs):
  """
    Creates an rnn cell object.

    Args:
        **kwargs: must contain `cell_type` ('lstm', 'gru' or
          'LayerNormBasicLSTMCell', case-insensitive), `size` and `num_layers`.
          `dropout_keep_prob` is optional (default 1.0, i.e. no dropout):
          - a list/tuple with one ratio per layer wraps each layer's cell in
            its own DropoutWrapper (input and output dropout with that ratio);
          - a scalar wraps the whole architecture (single or stacked cell) in
            one DropoutWrapper, applied only when it is < 1.0.
          Callers are responsible for passing 1.0 when dropout is not wanted
          (e.g. at inference time). Any other keys (such as `scope`) are
          accepted and ignored.

    Returns:
        An RNNCell; a MultiRNNCell (state_is_tuple=True) when `num_layers` > 1.

    Raises:
        ValueError: if `num_layers` < 1, or `dropout_keep_prob` is a list/tuple
          whose length differs from `num_layers`.
        Exception: if `cell_type` is not supported.
    """
  cell_type = kwargs['cell_type']
  size = kwargs['size']
  num_layers = kwargs['num_layers']
  dropout_keep_prob = kwargs.get('dropout_keep_prob', 1.0)

  if num_layers < 1:
    raise ValueError('`num_layers` must be at least 1, got {}.'.format(
        num_layers))

  separate_dropout = isinstance(dropout_keep_prob, (list, tuple))
  if separate_dropout and len(dropout_keep_prob) != num_layers:
    raise ValueError(
        '`dropout_keep_prob` has {} entries but `num_layers` is {}.'.format(
            len(dropout_keep_prob), num_layers))

  kind = cell_type.lower()
  if kind == 'lstm':
    rnn_cell_constructor = tf.nn.rnn_cell.LSTMCell
  elif kind == 'gru':
    rnn_cell_constructor = tf.nn.rnn_cell.GRUCell
  elif kind == 'layernormbasiclstmcell':
    # NOTE(review): LayerNormBasicLSTMCell lived in tf.contrib.rnn; confirm
    # that `tf.compat.v1.rnn` exists in the installed TensorFlow version.
    rnn_cell_constructor = tf.compat.v1.rnn.LayerNormBasicLSTMCell
  else:
    raise Exception('Unsupported RNN Cell.')

  rnn_cells = []
  for i in range(num_layers):
    cell = rnn_cell_constructor(size)
    if separate_dropout:
      # Per-layer dropout: each layer uses its own keep ratio.
      cell = tf.nn.rnn_cell.DropoutWrapper(
          cell,
          input_keep_prob=dropout_keep_prob[i],
          output_keep_prob=dropout_keep_prob[i],
          state_keep_prob=1,
          dtype=tf.float32,
          seed=1)
    rnn_cells.append(cell)

  if num_layers > 1:
    cell = tf.nn.rnn_cell.MultiRNNCell(cells=rnn_cells, state_is_tuple=True)
  else:
    cell = rnn_cells[0]

  # A scalar keep ratio wraps the whole (possibly stacked) cell once.
  if not separate_dropout and dropout_keep_prob < 1.0:
    cell = tf.nn.rnn_cell.DropoutWrapper(
        cell,
        input_keep_prob=dropout_keep_prob,
        output_keep_prob=dropout_keep_prob,
        state_keep_prob=1,
        dtype=tf.float32,
        seed=1)
  return cell


import numpy as np


def logli_normal_bivariate(x, mu, sigma, rho, reduce_sum=False):
  """
    Bivariate Gaussian log-likelihood. Rank of arguments is expected to be 3.

    Computes, elementwise,
      -log(2*pi*sigma1*sigma2*sqrt(1 - rho^2)) - Z / (2*(1 - rho^2))
    with
      Z = ((x1-mu1)/sigma1)^2 + ((x2-mu2)/sigma2)^2
          - 2*rho*(x1-mu1)*(x2-mu2)/(sigma1*sigma2).
    Denominators and log arguments are clamped below at 1e-9.

    Args:
        x: data samples with shape (batch_size, num_time_steps, data_size).
          The last axis is split into two equal halves (x1, x2); presumably
          data_size == 2, i.e. one coordinate each — confirm with callers.
        mu: means, split along the last axis like `x` into (mu1, mu2).
        sigma: standard deviations, split along the last axis like `x` into
          (sigma1, sigma2).
        rho: correlation coefficient; not split, so it presumably has the same
          shape as one half of `x` — TODO confirm.
        reduce_sum: False to return the elementwise result; otherwise passed
          as the `axis` argument of `tf.reduce_sum` (None sums everything).

    Returns:
        Log-likelihood tensor (reduced if `reduce_sum` is not False).
    """
  # Last axis index computed dynamically from the tensor rank.
  last_axis = tf.rank(x) - 1
  x1, x2 = tf.split(x, 2, axis=last_axis)
  mu1, mu2 = tf.split(mu, 2, axis=last_axis)
  sigma1, sigma2 = tf.split(sigma, 2, axis=last_axis)

  with tf.name_scope('logli_normal_bivariate'):
    x_mu1 = tf.subtract(x1, mu1)
    x_mu2 = tf.subtract(x2, mu2)
    # Quadratic form of the bivariate normal exponent (before dividing by
    # 2*(1 - rho^2)); 1e-9 floors avoid division by zero.
    Z = tf.square(tf.div(x_mu1, tf.maximum(1e-9, sigma1))) + \
        tf.square(tf.div(x_mu2, tf.maximum(1e-9, sigma2))) - \
        2*tf.div(tf.multiply(rho, tf.multiply(x_mu1, x_mu2)), tf.maximum(1e-9, tf.multiply(sigma1, sigma2)))

    # 1 - rho^2, floored so that |rho| -> 1 does not produce 0 or negatives.
    rho_square_term = tf.maximum(1e-9, 1 - tf.square(rho))
    # log of the normalization constant 2*pi*sigma1*sigma2*sqrt(1 - rho^2).
    log_regularize_term = tf.log(
        tf.maximum(
            1e-9, 2 * np.pi *
            tf.multiply(tf.multiply(sigma1, sigma2), tf.sqrt(rho_square_term))))
    log_power_e = tf.div(Z, 2 * rho_square_term)
    result = -(log_regularize_term + log_power_e)

    if reduce_sum is False:
      return result
    else:
      return tf.reduce_sum(result, reduce_sum)


def logli_normal_diag_cov(x, mu, sigma, reduce_sum=False):
  """
    Elementwise log-density of a Gaussian with diagonal covariance.

    Computes -( (x - mu)^2 / (2*sigma^2) + log(sqrt(pi * 2*sigma^2)) ), where
    2*sigma^2 is clamped below at 1e-6 for numerical stability.

    Args:
        x: observations.
        mu: mean, broadcast-compatible with `x`.
        sigma: standard deviation, broadcast-compatible with `x`.
        reduce_sum: False to return the elementwise result; otherwise passed
          as the `axis` argument of `tf.reduce_sum` (None sums everything).

    Returns:
        Log-likelihood tensor.
    """
  with tf.name_scope('logli_normal_diag_cov'):
    two_var = tf.maximum(1e-6, tf.square(sigma) * 2)
    log_norm_const = tf.log(tf.sqrt(np.pi * two_var))
    squared_err = tf.square(tf.subtract(x, mu))
    exponent = tf.div(squared_err, two_var)
    log_density = -(exponent + log_norm_const)

    if reduce_sum is not False:
      log_density = tf.reduce_sum(log_density, reduce_sum)
    return log_density


def logli_bernoulli(x, theta, reduce_sum=False):
  """
    Elementwise Bernoulli log-likelihood x*log(theta) + (1-x)*log(1-theta).

    Both log arguments are clamped below at 1e-9 to avoid log(0).

    Args:
        x: observed outcomes (presumably in {0, 1}; not checked).
        theta: success probability, broadcast-compatible with `x`.
        reduce_sum: False for the elementwise result; otherwise passed as the
          `axis` argument of `tf.reduce_sum`.

    Returns:
        Log-likelihood tensor.
    """
  with tf.name_scope('logli_bernoulli'):
    log_p_one = tf.log(tf.maximum(1e-9, theta))
    one_term = tf.multiply(x, log_p_one)
    complement = 1 - x
    log_p_zero = tf.log(tf.maximum(1e-9, 1 - theta))
    zero_term = tf.multiply(complement, log_p_zero)
    loglik = one_term + zero_term

    return loglik if reduce_sum is False else tf.reduce_sum(
        loglik, reduce_sum)


def kld_normal_isotropic(mu1, sigma1, mu2, sigma2, reduce_sum=False):
  """
    Kullback-Leibler divergence between two isotropic Gaussian distributions.

    Computes KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) per element with the
    closed form
      log(sigma2) - log(sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2*sigma2^2) - 1/2
    and sums it over the last axis (keepdims=True). Sigmas are applied
    elementwise, so per-dimension (diagonal) standard deviations also work.
    Log and division arguments are clamped below at 1e-9.

    Args:
        mu1: mean of the first distribution.
        sigma1: standard deviation of the first distribution.
        mu2: mean of the second distribution.
        sigma2: standard deviation of the second distribution.
        reduce_sum: False to return the per-sample result; otherwise passed
          as the `axis` argument of a further `tf.reduce_sum`.

    Returns:
        KL tensor whose last axis has size 1 (reduced further if `reduce_sum`
        is not False).
    """
  with tf.name_scope('kld_normal_isotropic'):
    result = tf.reduce_sum(
        0.5 * (2 * tf.log(tf.maximum(1e-9, sigma2)) -
               2 * tf.log(tf.maximum(1e-9, sigma1)) +
               (tf.square(sigma1) + tf.square(mu1 - mu2)) /
               tf.maximum(1e-9, (tf.square(sigma2))) - 1),
        keepdims=True,
        axis=-1)

    if reduce_sum is False:
      return result
    else:
      return tf.reduce_sum(result, reduce_sum)


"""VRNN cell classes. Cell functionality is decomposed into basic methods so that minor variations can be easily implemented by following OOP paradigm.

`build_training_graph` and `build_sampling_graph` methods are used to create
the cell by tensorflow's forward call.
"""


class VRNNCell(tf.nn.rnn_cell.RNNCell):
  """
    Variational RNN cell.
    Training time behaviour: draws latent vectors from the approximate
    posterior distribution and tries to decrease the discrepancy between the
    prior and the approximate posterior distributions. The forward-call input
    provides the observed data x.
    Sampling time behaviour: draws latent vectors from the prior distribution
    to synthesize a sample. This synthetic sample is then used to calculate
    the approximate posterior distribution which is fed to the RNN to update
    the state. In this mode the inputs to the forward call are not used and
    can be dummy.
    """

  def __init__(self, reuse, mode, config):
    """
        Args:
            reuse: reuse model parameters (used for the variable scope in
              `__call__`).
            mode: 'training' or 'sampling'. Any mode other than 'training'
              disables input dropout.
            config (dict): In addition to standard <key, value> pairs, stores
              the following dictionaries for rnn and output configurations.
              config['output'] = {} with parallel lists
                config['output']['keys'], config['output']['dims'] and
                config['output']['activation_funcs'].
              config['*_rnn'] = {} for '*' in input/latent/output. input_rnn
                and output_rnn are optional (fully connected layers are used
                when they are missing, None or empty). The dictionaries are
                forwarded to `get_rnn_cell`, which requires
                config['*_rnn']['num_layers'],
                config['*_rnn']['cell_type'] and
                config['*_rnn']['size'].
    """
    self.input_dims = config['input_dims']
    self.h_dim = config['latent_hidden_size']
    self.z_dim = config['latent_size']
    self.additive_q_mu = config['additive_q_mu']

    # NOTE: the key is spelled 'input_keep_prop' (sic); kept unchanged for
    # compatibility with existing configs.
    self.dropout_keep_prob = config.get('input_keep_prop', 1)
    self.num_linear_layers = config.get('num_fc_layers', 1)
    self.use_latent_h_in_outputs = config.get('use_latent_h_in_outputs', True)
    self.use_batch_norm = config['use_batch_norm_fc']

    self.reuse = reuse
    self.mode = mode
    self.is_sampling = mode == 'sampling'

    # Input dropout is only used while training.
    if mode != 'training':
      self.dropout_keep_prob = 1.0

    self.output_config = config['output']

    # q_mu, q_sigma, p_mu, p_sigma + model outputs
    self.output_size_ = [self.z_dim] * 4
    self.output_size_.extend(self.output_config['dims'])

    # State layout: [input_rnn_state], latent_rnn_state, [output_rnn_state].
    self.state_size_ = []
    # Optional. Linear layers will be used if not passed (None or empty).
    self.input_rnn = False
    if config.get('input_rnn'):
      self.input_rnn = True
      self.input_rnn_config = config['input_rnn']

      self.input_rnn_cell = get_rnn_cell(
          scope='input_rnn', **config['input_rnn'])
      self.state_size_.append(self.input_rnn_cell.state_size)

    self.latent_rnn_config = config['latent_rnn']
    self.latent_rnn_cell_type = config['latent_rnn']['cell_type']
    self.latent_rnn_cell = get_rnn_cell(
        scope='latent_rnn', **config['latent_rnn'])
    self.state_size_.append(self.latent_rnn_cell.state_size)

    # Optional. Linear layers will be used if not passed (None or empty).
    self.output_rnn = False
    if config.get('output_rnn'):
      self.output_rnn = True
      self.output_rnn_config = config['output_rnn']

      self.output_rnn_cell = get_rnn_cell(
          scope='output_rnn', **config['output_rnn'])
      self.state_size_.append(self.output_rnn_cell.state_size)

    self.activation_func = get_activation_fn(
        config.get('fc_layer_activation_func', 'relu'))
    self.sigma_func = get_activation_fn('softplus')

  @property
  def state_size(self):
    return tuple(self.state_size_)

  @property
  def output_size(self):
    return tuple(self.output_size_)

  #
  # Auxiliary functions
  #
  def draw_sample(self):
    """
        Returns the sample fed back as input during sampling: the predicted
        mean output `out_mu` (no random draw is made).
        """
    return self.output_components['out_mu']

  def reparametrization(self, mu, sigma, scope):
    """
        Given an isotropic normal distribution (mu and sigma), draws a sample by
        using the reparametrization trick:
        z = mu + sigma*epsilon, epsilon ~ N(0, 1)
        Args:
            mu: mean of isotropic Gaussian distribution.
            sigma: standard deviation of isotropic Gaussian distribution.
            scope: variable scope name for the sampling ops.

        Returns:
            The sample z, with the shape of `sigma`.
        """
    with tf.variable_scope(scope):
      # Dynamic tf.shape (instead of the static get_shape()) so that an
      # unknown (None) batch dimension does not break tf.random_normal.
      eps = tf.random_normal(tf.shape(sigma), 0.0, 1.0, dtype=tf.float32)
      z = tf.add(mu, tf.multiply(sigma, eps))

      return z

  def phi(self, input_, scope, reuse=None):
    """
        Stack of `num_fc_layers` fully connected layers (h_dim units each) to
        increase model capacity and learn an intermediate representation. It
        is reported to be useful in https://arxiv.org/pdf/1506.02216.pdf
        Args:
            input_: input tensor.
            scope: variable scope name.
            reuse: forwarded to tf.variable_scope.

        Returns:
            Output of the last layer.
        """
    with tf.variable_scope(scope, reuse=reuse):
      phi_hidden = input_
      for _ in range(self.num_linear_layers):
        # NOTE(review): `is_training` is not forwarded, so batch norm (when
        # enabled) runs in training mode in every mode, including sampling.
        phi_hidden = linear(
            phi_hidden,
            self.h_dim,
            self.activation_func,
            batch_norm=self.use_batch_norm)

      return phi_hidden

  def latent(self, input_, scope):
    """
        Creates mu and sigma components of a latent distribution. Given an input
        layer, first applies a fully connected layer and then calculates mu
        (linear) and sigma (softplus).
        Args:
            input_: input tensor.
            scope: variable scope name.

        Returns:
            (mu, sigma), each with z_dim units.
        """
    with tf.variable_scope(scope):
      latent_hidden = linear(
          input_,
          self.h_dim,
          self.activation_func,
          batch_norm=self.use_batch_norm)
      with tf.variable_scope('mu'):
        mu = linear(latent_hidden, self.z_dim)
      with tf.variable_scope('sigma'):
        sigma = linear(latent_hidden, self.z_dim, self.sigma_func)

      return mu, sigma

  def parse_rnn_state(self, state):
    """
        Unpacks the cell state and sets self.latent_h and the rnn states.

        `state` is ordered as built in `__call__`: [input_rnn_state]
        (optional), latent_rnn_state, [output_rnn_state] (optional).
        self.latent_h is the latent RNN's hidden output: the state itself for
        GRU cells, otherwise its `.h` field. For a stacked latent RNN
        (num_layers > 1) the top layer's state is used.
        Args:
            state: tuple of rnn states.
        """
    latent_rnn_state_idx = 0
    if self.input_rnn:
      self.input_rnn_state = state[0]
      latent_rnn_state_idx = 1
    if self.output_rnn:
      self.output_rnn_state = state[latent_rnn_state_idx + 1]

    self.latent_rnn_state = state[latent_rnn_state_idx]

    top_state = self.latent_rnn_state
    if self.latent_rnn_config.get('num_layers', 1) > 1:
      # MultiRNNCell state is a tuple with one entry per layer.
      top_state = top_state[-1]

    if self.latent_rnn_cell_type.lower() == 'gru':
      self.latent_h = top_state
    else:
      self.latent_h = top_state.h

  #
  # Functions to build graph.
  #
  def build_training_graph(self, input_, state):
    """
        Builds one training step: the latent sample comes from the approximate
        posterior q.
        Args:
            input_: observed input for this step.
            state: cell state tuple.
        """
    self.parse_rnn_state(state)
    self.input_layer(input_, state)
    self.input_layer_hidden()

    self.latent_p_layer()
    self.latent_q_layer()
    self.phi_z = self.phi_z_q

    self.output_layer_hidden()
    self.output_layer()
    self.update_latent_rnn_layer()

  def build_sampling_graph(self, input_, state):
    """
        Builds one sampling step: the latent sample comes from the prior p, and
        the drawn output is fed back as input to update the posterior and the
        latent RNN. `input_` is not used.
        """
    self.parse_rnn_state(state)
    self.latent_p_layer()
    self.phi_z = self.phi_z_p

    self.output_layer_hidden()
    self.output_layer()

    # Draw a sample by using predictive distribution.
    synthetic_sample = self.draw_sample()
    self.input_layer(synthetic_sample, state)
    self.input_layer_hidden()
    self.latent_q_layer()
    self.update_latent_rnn_layer()

  def input_layer(self, input_, state):
    """
        Splits `input_` along axis 1 into chunks of sizes `self.input_dims`
        and sets self.x to the first chunk. No dropout is applied here (see
        `input_layer_hidden`).
        Args:
            input_: input tensor.
            state: unused.
        """
    with tf.variable_scope('input'):
      input_components = tf.split(input_, self.input_dims, axis=1)
      self.x = input_components[0]

  def input_layer_hidden(self):
    """Sets self.phi_x_input from self.x (input RNN or phi), with dropout."""
    if self.input_rnn is True:
      self.phi_x_input, self.input_rnn_state = self.input_rnn_cell(
          self.x, self.input_rnn_state, scope='phi_x_input')
    else:
      self.phi_x_input = self.phi(self.x, scope='phi_x_input')

    if self.dropout_keep_prob < 1.0:
      self.phi_x_input = tf.nn.dropout(
          self.phi_x_input, keep_prob=self.dropout_keep_prob)

  def latent_q_layer(self):
    """Approximate posterior q(z | x, h); reuses the 'phi_z' variables."""
    input_latent_q = tf.concat((self.phi_x_input, self.latent_h), axis=1)
    if self.additive_q_mu:
      # q's mean is predicted as an offset from the prior mean.
      q_mu_delta, self.q_sigma = self.latent(input_latent_q, scope='latent_z_q')
      self.q_mu = q_mu_delta + self.p_mu
    else:
      self.q_mu, self.q_sigma = self.latent(input_latent_q, scope='latent_z_q')

    q_z = self.reparametrization(self.q_mu, self.q_sigma, scope='z_q')
    # 'phi_z' variables are created by latent_p_layer, which runs first.
    self.phi_z_q = self.phi(q_z, scope='phi_z', reuse=True)

  def latent_p_layer(self):
    """Prior p(z | h); creates the 'phi_z' variables."""
    # The prior depends only on the latent RNN output.
    input_latent_p = self.latent_h
    self.p_mu, self.p_sigma = self.latent(input_latent_p, scope='latent_z_p')

    p_z = self.reparametrization(self.p_mu, self.p_sigma, scope='z_p')
    self.phi_z_p = self.phi(p_z, scope='phi_z')

  def output_layer_hidden(self):
    """Sets self.phi_x_output from phi_z (and latent_h if configured)."""
    if self.use_latent_h_in_outputs is True:
      output_layer_hidden = tf.concat((self.phi_z, self.latent_h), axis=1)
    else:
      output_layer_hidden = self.phi_z

    if self.output_rnn is True:
      self.phi_x_output, self.output_rnn_state = self.output_rnn_cell(
          output_layer_hidden, self.output_rnn_state, scope='phi_x_output')
    else:
      self.phi_x_output = self.phi(output_layer_hidden, scope='phi_x_output')

  def output_layer(self):
    """One linear head per configured output key, stored in output_components."""
    self.output_components = {}
    for key, size, activation_func in zip(
        self.output_config['keys'], self.output_config['dims'],
        self.output_config['activation_funcs']):
      with tf.variable_scope(key):
        output_component = linear(
            self.phi_x_output,
            size,
            activation_fn=get_activation_fn(activation_func))
        self.output_components[key] = output_component

  def update_latent_rnn_layer(self):
    """Advances the latent RNN with [phi_x_input, phi_z]."""
    input_latent_rnn = tf.concat((self.phi_x_input, self.phi_z), axis=1)
    self.latent_rnn_output, self.latent_rnn_state = self.latent_rnn_cell(
        input_latent_rnn, self.latent_rnn_state)

  def __call__(self, input_, state, scope=None):
    """
        Forward call. Returns (outputs, state) where outputs is
        (q_mu, q_sigma, p_mu, p_sigma, *model outputs in config order) and
        state follows the layout documented in `parse_rnn_state`.
        """
    with tf.variable_scope(scope or type(self).__name__, reuse=self.reuse):
      if self.is_sampling:
        self.build_sampling_graph(input_, state)
      else:
        self.build_training_graph(input_, state)

      # Prepare cell output.
      vrnn_cell_output = [self.q_mu, self.q_sigma, self.p_mu, self.p_sigma]
      for key in self.output_config['keys']:
        vrnn_cell_output.append(self.output_components[key])

      # Prepare cell state.
      vrnn_cell_state = []
      if self.input_rnn:
        vrnn_cell_state.append(self.input_rnn_state)

      vrnn_cell_state.append(self.latent_rnn_state)

      if self.output_rnn:
        vrnn_cell_state.append(self.output_rnn_state)

      return tuple(vrnn_cell_output), tuple(vrnn_cell_state)


class VRNNGmmCell(VRNNCell):
  """
    Variational RNN cell with GMM latent space option. See parent class for
    method documentation. Please note that here we use a GMM to learn a
    continuous representation for categorical inputs. With hard assignment
    (argmax over pi) no gradient flows back to pi through the GMM; the pi
    branch is expected to be trained with a classification loss elsewhere.
    Training time behaviour: draws latent vectors from the approximate
    posterior distribution and tries to decrease the discrepancy between the
    prior and the approximate posterior distributions.
    Sampling time behaviour: draws latent vectors from the prior distribution
    to synthesize a sample. This synthetic sample is then used to calculate
    the approximate posterior distribution which is fed to the RNN to update
    the state. In sampling mode the forward-call input is only read when
    `use_real_pi_labels` is True (labels); otherwise it can be dummy.
    Graph-building methods not overridden here (latent_p_layer,
    latent_q_layer, input_layer_hidden, output_layer) are inherited from
    VRNNCell.
    """

  def __init__(self, reuse, mode, config):
    super(VRNNGmmCell, self).__init__(reuse, mode, config)

    # Latent variable z.
    self.use_temporal_latent_space = config.get('use_temporal_latent_space',
                                                True)

    # Latent variable \pi (i.e., content branch).
    # Options for content branch samples: soft_gmm, hard_gmm, pi. (default: hard_gmm)
    # If True, label probabilities (\pi) or one-hot-encoded labels are used as
    # content directly, bypassing the GMM.
    self.use_pi_as_content = config.get('use_pi_as_content', False)
    self.is_gmm_active = not self.use_pi_as_content
    self.use_soft_gmm = config.get('use_soft_gmm', False)

    # Ground truth content label is given as the second input component.
    self.use_real_pi_labels = config.get('use_real_pi_labels', False)
    self.use_variational_pi = config.get('use_variational_pi', False)

    assert (not self.use_pi_as_content) or self.use_real_pi_labels, (
        '`use_real_pi_labels` must be True if `use_pi_as_content` is True.')
    assert self.use_variational_pi or self.use_temporal_latent_space, (
        "Both `use_temporal_latent_space` and `use_variational_pi` can't be "
        "False.")
    assert not (self.use_soft_gmm and self.use_real_pi_labels), (
        'Both `use_soft_gmm` and `use_real_pi_labels` are True.')
    assert not (self.use_real_pi_labels and len(self.input_dims) < 2), (
        'Class labels are missing for pi_labels.')

    if (self.use_soft_gmm is False) and (self.use_temporal_latent_space is
                                         False):
      print('Warning: there is no differentiable latent space component.')
    if (self.use_variational_pi is False) and (self.use_real_pi_labels is
                                               False):
      print('Warning: there is no information source for GMM components.')

    self.num_gmm_components = config['num_gmm_components']
    self.gmm_component_size = config['gmm_component_size']

    if self.is_gmm_active:
      # Create mean and sigma variables of gmm components, each of shape
      # [num_gmm_components, gmm_component_size].
      with tf.variable_scope('gmm_latent', reuse=self.reuse):
        self.gmm_mu_vars = tf.get_variable(
            'mu',
            dtype=tf.float32,
            initializer=tf.random_uniform(
                [self.num_gmm_components, self.gmm_component_size], -1.0, 1.0))
        # Softplus keeps the standard deviations positive.
        self.gmm_sigma_vars = self.sigma_func(
            tf.get_variable(
                'sigma',
                dtype=tf.float32,
                initializer=tf.constant_initializer(1),
                shape=[self.num_gmm_components, self.gmm_component_size]))
    else:
      self.gmm_mu_vars = self.gmm_sigma_vars = None
      # The content vector is the label input itself.
      self.gmm_component_size = self.input_dims[1]

    self.output_size_ = []
    if self.use_temporal_latent_space:
      self.output_size_.extend([self.z_dim] * 4)  # q_mu, q_sigma, p_mu, p_sigma

    self.output_size_.append(self.gmm_component_size)  # z_gmm
    self.output_size_.append(self.num_gmm_components)  # q_pi
    if self.use_variational_pi:
      self.output_size_.append(self.num_gmm_components)  # p_pi

    self.output_size_.extend(self.output_config['dims'])  # model outputs

  def get_gmm_components(self):
    """Returns (gmm_mu_vars, gmm_sigma_vars); both None if the GMM is off."""
    return self.gmm_mu_vars, self.gmm_sigma_vars

  #
  # Functions to build graph.
  #
  def build_training_graph(self, input_, state):
    """Training step: z from q (if enabled); GMM driven by q_pi or labels."""
    self.parse_rnn_state(state)
    self.input_layer(input_, state)
    self.input_layer_hidden()

    if self.use_temporal_latent_space:
      self.latent_p_layer()
      self.latent_q_layer()
      self.phi_z = self.phi_z_q  # Use approximate distribution in training mode.

    if self.use_variational_pi:
      self.latent_p_pi()

    self.latent_q_pi()

    self.gmm_pi = self.logits_q_pi
    if self.use_real_pi_labels:
      self.gmm_pi = self.real_pi

    self.latent_gmm()

    self.output_layer_hidden()
    self.output_layer()
    self.update_latent_rnn_layer()

  def build_sampling_graph(self, input_, state):
    """
        Sampling step: z from the prior; GMM driven by p_pi or fed labels.

        Raises:
            ValueError: if neither `use_variational_pi` nor
              `use_real_pi_labels` is set, since nothing then defines the GMM
              component weights.
        """
    if not (self.use_variational_pi or self.use_real_pi_labels):
      raise ValueError(
          'Sampling requires `use_variational_pi` or `use_real_pi_labels` '
          'to select GMM components.')

    self.parse_rnn_state(state)
    if self.use_real_pi_labels:  # Labels are fed for sampling.
      self.input_layer(input_, state)

    if self.use_temporal_latent_space:
      self.latent_p_layer()
      self.phi_z = self.phi_z_p

    if self.use_variational_pi:
      self.latent_p_pi()
      self.gmm_pi = self.logits_p_pi

    if self.use_real_pi_labels:
      self.gmm_pi = self.real_pi

    self.latent_gmm()

    self.output_layer_hidden()
    self.output_layer()

    # Draw a sample by using predictive distribution.
    synthetic_sample = self.draw_sample()  # This will update self.x
    self.input_layer(synthetic_sample, state)
    self.input_layer_hidden()

    if self.use_temporal_latent_space:
      self.latent_q_layer()

    self.latent_q_pi()

    self.update_latent_rnn_layer()

  def latent_gmm(self):
    """
        Sets self.gmm_z (content vector) from self.gmm_pi and self.phi_z_gmm.

        - use_pi_as_content: gmm_z is gmm_pi itself.
        - soft GMM: gmm_z is the gmm_pi-weighted sum of one sample from every
          component.
        - hard GMM (default): gmm_z is a sample from the argmax component.
        """
    if self.use_pi_as_content:
      self.gmm_z = self.gmm_pi

    elif self.use_soft_gmm:
      with tf.name_scope('latent_z_gmm'):
        # Dynamic batch size so that an unknown (None) batch dimension works.
        batch_size = tf.shape(self.gmm_pi)[0]
        eps = tf.random_normal(
            tf.stack([batch_size, self.num_gmm_components,
                      self.gmm_component_size]),
            0.0,
            1.0,
            dtype=tf.float32)
        z = tf.add(self.gmm_mu_vars, tf.multiply(self.gmm_sigma_vars, eps))

        # NOTE(review): in training gmm_pi holds q_pi *logits* (no softmax),
        # so the mixture weights below are unnormalized — confirm intent.
        gmm_pi = tf.expand_dims(self.gmm_pi, axis=1)
        # [batch, 1, num_components] x [batch, num_components, component_size] -> [batch, 1, component_size]
        self.gmm_z = tf.squeeze(tf.matmul(gmm_pi, z), axis=1)
    else:
      with tf.name_scope('latent_z_gmm'):
        mixture_components = tf.expand_dims(
            tf.argmax(self.gmm_pi, axis=-1), axis=-1)
        gmm_mu = tf.gather_nd(self.gmm_mu_vars, mixture_components)
        gmm_sigma = tf.gather_nd(self.gmm_sigma_vars, mixture_components)
      # z = mu + sigma*epsilon
      self.gmm_z = self.reparametrization(
          gmm_mu, gmm_sigma, scope='latent_z_gmm')

    self.phi_z_gmm = self.phi(self.gmm_z, scope='phi_z_gmm')

  def latent_q_pi(self):
    """Sets self.logits_q_pi from [x, latent_h] (approximate posterior)."""
    input_ = tf.concat((self.x, self.latent_h), axis=1)
    with tf.variable_scope('latent_q_pi'):
      phi_pi = linear(
          input_,
          self.h_dim,
          self.activation_func,
          batch_norm=self.use_batch_norm)
      # `linear` skips batch norm when activation_fn is None.
      self.logits_q_pi = linear(
          phi_pi,
          self.num_gmm_components,
          activation_fn=None,
          batch_norm=self.use_batch_norm)

  def latent_p_pi(self):
    """Sets self.logits_p_pi from latent_h only (prior)."""
    input_ = self.latent_h
    with tf.variable_scope('latent_p_pi'):
      phi_pi = linear(
          input_,
          self.h_dim,
          self.activation_func,
          batch_norm=self.use_batch_norm)
      # `linear` skips batch norm when activation_fn is None.
      self.logits_p_pi = linear(
          phi_pi,
          self.num_gmm_components,
          activation_fn=None,
          batch_norm=self.use_batch_norm)

  def input_layer(self, input_, state):
    """
        Splits `input_` along axis 1 by `self.input_dims`: the first chunk is
        self.x; the second is self.real_pi when `use_real_pi_labels` is set.
        No dropout is applied here. `state` is unused.
        """
    with tf.variable_scope('input'):
      input_components = tf.split(input_, self.input_dims, axis=1)
      self.x = input_components[0]

      if self.use_real_pi_labels:
        self.real_pi = input_components[1]

  def output_layer_hidden(self):
    """Sets self.phi_x_output from [phi_z_gmm, phi_z?, latent_h?]."""
    input_list = [self.phi_z_gmm]
    if self.use_temporal_latent_space:
      input_list.append(self.phi_z)

    if self.use_latent_h_in_outputs is True:
      input_list.append(self.latent_h)

    inputs_ = tf.concat(input_list, axis=1)

    if self.output_rnn is True:
      self.phi_x_output, self.output_rnn_state = self.output_rnn_cell(
          inputs_, self.output_rnn_state, scope='phi_x_output')
    else:
      self.phi_x_output = self.phi(inputs_, scope='phi_x_output')

  def update_latent_rnn_layer(self):
    """Advances the latent RNN with [phi_x_input, phi_z_gmm, phi_z?]."""
    input_list = [self.phi_x_input, self.phi_z_gmm]

    if self.use_temporal_latent_space:
      input_list.append(self.phi_z)

    input_latent_rnn = tf.concat(input_list, axis=1)
    self.latent_rnn_output, self.latent_rnn_state = self.latent_rnn_cell(
        input_latent_rnn, self.latent_rnn_state)

  def __call__(self, input_, state, scope=None):
    """
        Forward call. Outputs are ordered as in `output_size_`:
        [q_mu, q_sigma, p_mu, p_sigma] (if temporal latent space), gmm_z,
        logits_q_pi, [logits_p_pi] (if variational pi), model outputs.
        """
    with tf.variable_scope(scope or type(self).__name__, reuse=self.reuse):
      if self.is_sampling:
        self.build_sampling_graph(input_, state)
      else:
        self.build_training_graph(input_, state)

      # Prepare cell output.
      vrnn_cell_output = []
      if self.use_temporal_latent_space:
        vrnn_cell_output.extend(
            [self.q_mu, self.q_sigma, self.p_mu, self.p_sigma])

      vrnn_cell_output.append(self.gmm_z)
      vrnn_cell_output.append(self.logits_q_pi)
      if self.use_variational_pi:
        vrnn_cell_output.append(self.logits_p_pi)

      for key in self.output_config['keys']:
        vrnn_cell_output.append(self.output_components[key])

      # Prepare cell state.
      vrnn_cell_state = []
      if self.input_rnn:
        vrnn_cell_state.append(self.input_rnn_state)

      vrnn_cell_state.append(self.latent_rnn_state)

      if self.output_rnn:
        vrnn_cell_state.append(self.output_rnn_state)

      return tuple(vrnn_cell_output), tuple(vrnn_cell_state)


class HandWritingVRNNGmmCell(VRNNGmmCell):
  """
    Variational RNN cell with a GMM latent space for digital handwriting.

    Training time: latent vectors are drawn from the approximate posterior and
    the discrepancy between the prior and the approximate posterior is
    reduced.
    Inference time: latent vectors are drawn from the prior to synthesize a
    sample. That synthetic sample is then used to compute the approximate
    posterior, which is fed to the RNN to update its state. In this mode the
    inputs of the forward call are not used and may be dummy values.
    """

  def __init__(self, reuse, mode, config):
    super().__init__(reuse, mode, config)

    # Whether beginning-of-word (bow) labels are part of the input.
    self.use_bow_labels = config.get('use_bow_labels', False)
    # Pen values strictly greater than this threshold become 1 in binarize().
    self.pen_threshold = config.get('pen_threshold', 0.4)

  # Auxiliary functions.
  def binarize(self, input_):
    """
        Maps continuous values to {0, 1} with a step at `self.pen_threshold`.
        Args:
            input_: tensor with continuous values (presumably in [0, 1]).

        Returns:
            float tensor shaped like `input_`: 1.0 where the value is strictly
            greater than `pen_threshold`, 0.0 elsewhere.
        """
    shape = tf.shape(input_)
    above = tf.greater(input_, tf.fill(shape, self.pen_threshold))
    ones = tf.fill(shape, 1.0)
    zeros = tf.fill(shape, 0.0)
    return tf.where(above, ones, zeros)

  def draw_sample(self):
    """Builds a deterministic sample: `out_mu` plus the binarized pen output.

    The sample also carries `gmm_pi` when real pi labels are used, and the bow
    labels when they are enabled. Pieces are concatenated along axis 1.
    """
    pieces = [self.output_components['out_mu']]
    pieces.append(self.binarize(self.output_components['out_pen']))
    if self.use_real_pi_labels:
      pieces.append(self.gmm_pi)
    if self.use_bow_labels:
      pieces.append(self.bow_labels)
    return tf.concat(pieces, axis=1)

  def input_layer(self, input_, state):
    """Splits the step input into strokes and optional label pieces."""
    with tf.variable_scope('input'):
      pieces = tf.split(input_, self.input_dims, axis=1)
      # Dropout is only applied to the first piece.
      self.x = tf.nn.dropout(pieces[0], keep_prob=self.dropout_keep_prob)

      if self.use_real_pi_labels:
        self.real_pi = pieces[1]  # Character labels.

      if self.use_bow_labels:
        # NOTE(review): always reads the third piece, i.e. assumes the input
        # layout [x, character labels, bow labels] — confirm with the dataset.
        self.bow_labels = pieces[2]

  def latent_p_layer(self):
    """Prior over z, conditioned on `self.latent_h` only."""
    prior_input = tf.concat([self.latent_h], axis=1)
    self.p_mu, self.p_sigma = self.latent(prior_input, scope='latent_z_p')

    z_prior = self.reparametrization(self.p_mu, self.p_sigma, scope='z_p')
    self.phi_z_p = self.phi(z_prior, scope='phi_z')

  def latent_q_layer(self):
    """Approximate posterior over z, conditioned on the input and latent_h.

    With `additive_q_mu`, the predicted mean is an offset added to `p_mu`.
    The `phi_z` weights are shared with the prior branch (`reuse=True`).
    """
    posterior_input = tf.concat([self.phi_x_input, self.latent_h], axis=1)

    if not self.additive_q_mu:
      self.q_mu, self.q_sigma = self.latent(
          posterior_input, scope='latent_z_q')
    else:
      mu_offset, self.q_sigma = self.latent(
          posterior_input, scope='latent_z_q')
      self.q_mu = mu_offset + self.p_mu

    z_posterior = self.reparametrization(self.q_mu, self.q_sigma, scope='z_q')
    self.phi_z_q = self.phi(z_posterior, scope='phi_z', reuse=True)

  def output_layer_hidden(self):
    """Builds the hidden representation that feeds the output layer."""
    features = [self.phi_z_gmm]
    if self.use_temporal_latent_space:
      features.append(self.phi_z)
    if self.use_latent_h_in_outputs is True:
      features.append(self.latent_h)
    if self.use_bow_labels:
      features.append(self.bow_labels)

    hidden_input = tf.concat(features, axis=1)

    if self.output_rnn is not True:
      self.phi_x_output = self.phi(hidden_input, scope='phi_x_output')
    else:
      self.phi_x_output, self.output_rnn_state = self.output_rnn_cell(
          hidden_input, self.output_rnn_state, scope='phi_x_output')


import sys
import time
"""Vanilla variational recurrent neural network model. Assuming that model outputs are isotropic Gaussian distributions. The model is trained by using negative log-likelihood (reconstruction) and KL-divergence losses.

Model functionality is decomposed into basic functions (see build_graph
method) so that variants of the model can easily be implemented by inheriting
from the vanilla architecture.
"""


class VRNN():
  """Vanilla variational recurrent neural network model.

  Graph construction is decomposed into small methods (see `build_graph`) so
  that model variants can override individual steps.
  """

  def __init__(self,
               config,
               input_op,
               input_seq_length_op,
               target_op,
               input_dims,
               target_dims,
               reuse,
               batch_size=-1,
               mode='training'):
    """
        Args:
            config (dict): model configuration. NOTE: this exact dict object is
              reused as the cell configuration and gets an 'input_dims' entry.
            input_op: input tensor; presumably (batch, time, input_size), since
              it is unstacked along axis 1 for the static RNN.
            input_seq_length_op: tensor with one sequence length per sample.
            target_op: target tensor, split along axis 2 into `target_dims`.
            input_dims: sizes of the concatenated input pieces.
            target_dims: sizes of the concatenated target pieces.
            reuse: variable reuse flag forwarded to the cell.
            batch_size (int): -1 means `config['batch_size']`.
            mode (str): 'training', 'validation' or 'sampling'.
        """
    self.vrnn_cell_constructor = None

    self.config = config
    assert mode in ['training', 'validation', 'sampling']
    self.mode = mode
    self.is_sampling = mode == 'sampling'
    self.is_validation = mode == 'validation'
    self.is_training = mode == 'training'
    self.reuse = reuse

    self.inputs = input_op
    self.targets = target_op
    self.input_seq_length = input_seq_length_op
    self.input_dims = input_dims

    if target_op is not None or self.is_training or self.is_validation:
      self.target_dims = target_dims
      self.target_pieces = tf.split(self.targets, target_dims, axis=2)

    self.latent_size = self.config['latent_size']

    self.batch_size = config['batch_size'] if batch_size == -1 else batch_size

    # Reconstruction loss can be modeled differently. Create a key dynamically since the key is used in printing.
    self.reconstruction_loss = self.config.get('reconstruction_loss',
                                               'nll_normal')
    self.reconstruction_loss_key = 'loss_' + self.reconstruction_loss
    self.reconstruction_loss_weight = self.config.get('loss_weights', {}).get(
        'reconstruction_loss', 1)
    self.kld_loss_weight = self.config.get('loss_weights',
                                           {}).get('kld_loss', 1)

    # Function to get final loss value: average loss or summation.
    self.reduce_loss_func = get_reduce_loss_func(self.config['reduce_loss'],
                                                 self.input_seq_length)
    self.mean_sequence_func = get_reduce_loss_func('mean_per_step',
                                                   self.input_seq_length)

    # TODO: Create a dictionary just for cell arguments.
    # NOTE: this aliases `config`, so the caller's dict gains 'input_dims'.
    self.vrnn_cell_args = config
    self.vrnn_cell_args['input_dims'] = self.input_dims

    # To keep track of operations. List of graph nodes that must be evaluated by session.run during training.
    self.ops_loss = {}
    # Loss ops that are used to train the model.
    self.ops_training_loss = {}
    # (Default) graph ops to be fed into session.run while evaluating the model. Note that tf_evaluate* codes assume
    # to get at least these op results.
    self.ops_evaluation = {}
    # Graph ops for scalar summaries such as average predicted variance.
    self.ops_scalar_summary = {}

  def build_graph(self):
    """Builds the complete model graph in dependency order."""
    self.get_constructors()
    self.build_cell()
    self.build_rnn_layer()
    self.build_predictions_layer()
    self.build_loss()
    self.accumulate_loss()
    self.create_summary_plots()
    self.log_num_parameters()

  def get_constructors(self):
    """
        Enables loading project specific classes.
        """
    self.vrnn_cell_constructor = HandWritingVRNNGmmCell

  def build_cell(self):
    """Creates the recurrent cell and its zero initial state."""
    # The cell is constructed the same way in every mode (`mode` is validated
    # in __init__); the mode itself is forwarded to the cell.
    self.cell = self.vrnn_cell_constructor(
        reuse=self.reuse, mode=self.mode, config=self.vrnn_cell_args)

    assert isinstance(
        self.cell,
        VRNNCell), 'Cell object must be an instance of VRNNCell for VRNN model.'

    self.initial_state = self.cell.zero_state(
        batch_size=self.batch_size, dtype=tf.float32)

  def build_rnn_layer(self):
    """Unrolls the cell over the input sequence.

    Afterwards `self.outputs` holds one tensor per component of the cell
    output, each batch-major with time on axis 1, for both the dynamic and the
    static unrolling. `self.output_state` holds the final cell state.
    """
    if self.config['use_dynamic_rnn']:
      self.outputs, self.output_state = tf.nn.dynamic_rnn(
          self.cell,
          self.inputs,
          sequence_length=self.input_seq_length,
          initial_state=self.initial_state,
          dtype=tf.float32)
    else:
      inputs_static_rnn = tf.unstack(self.inputs, axis=1)
      self.outputs_static_rnn, self.output_state = tf.nn.static_rnn(
          self.cell,
          inputs_static_rnn,
          initial_state=self.initial_state,
          sequence_length=self.input_seq_length,
          dtype=tf.float32)

      # static_rnn returns one output tuple per time step. Regroup them per
      # output component (stacked over time on axis 1) so the structure matches
      # dynamic_rnn. Every component must be regrouped, not just the
      # config['output']['keys'] ones: the cell emits latent statistics first,
      # and build_predictions_layer unpacks all components.
      num_components = len(self.outputs_static_rnn[0])
      self.outputs = [
          tf.stack([step[n] for step in self.outputs_static_rnn], axis=1)
          for n in range(num_components)
      ]

  def build_predictions_layer(self):
    """Unpacks the RNN outputs and registers evaluation/summary ops."""
    # Assign rnn outputs.
    self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.out_mu, self.out_sigma = self.outputs

    # TODO: Sampling option.
    self.output_sample = self.out_mu
    self.input_sample = self.inputs
    self.output_dim = self.output_sample.shape.as_list()[-1]

    self.ops_evaluation['output_sample'] = self.output_sample
    self.ops_evaluation['p_mu'] = self.p_mu
    self.ops_evaluation['p_sigma'] = self.p_sigma
    self.ops_evaluation['q_mu'] = self.q_mu
    self.ops_evaluation['q_sigma'] = self.q_sigma
    self.ops_evaluation['state'] = self.output_state

    # NOTE(review): normalizer is (static batch size) * (sum of all sequence
    # lengths), which counts each step batch_size times; needs a static batch
    # dimension on input_seq_length. Kept as-is to not change reported values.
    num_entries = tf.cast(
        self.input_seq_length.shape.as_list()[0] *
        tf.reduce_sum(self.input_seq_length), tf.float32)
    self.ops_scalar_summary['mean_out_sigma'] = tf.reduce_sum(
        self.out_sigma) / num_entries
    self.ops_scalar_summary['mean_p_sigma'] = tf.reduce_sum(
        self.p_sigma) / num_entries
    self.ops_scalar_summary['mean_q_sigma'] = tf.reduce_sum(
        self.q_sigma) / num_entries

    # Mask for precise loss calculation: 1.0 for valid steps, 0.0 for padding.
    self.seq_loss_mask = tf.expand_dims(
        tf.sequence_mask(
            lengths=self.input_seq_length,
            maxlen=tf.reduce_max(self.input_seq_length),
            dtype=tf.float32), -1)

  def build_loss(self):
    """Adds reconstruction and KL-divergence losses to `self.ops_loss`.

    Losses are only built in training/validation mode; a loss already
    registered under the same key (e.g. by a subclass) is not rebuilt.

    Raises:
      Exception: if `self.reconstruction_loss` is not one of 'nll_normal',
        'nll_normal_iso', 'l1' or 'mse'.
    """
    if self.is_training or self.is_validation:
      # TODO: Use dataset object to parse the concatenated targets.
      targets_mu = self.target_pieces[0]

      if self.reconstruction_loss_key not in self.ops_loss:
        with tf.name_scope('reconstruction_loss'):
          # Gaussian log likelihood loss. 'nll_normal' is the default value of
          # the 'reconstruction_loss' config entry (see __init__), so it is
          # handled here together with 'nll_normal_iso'.
          if self.reconstruction_loss in ('nll_normal', 'nll_normal_iso'):
            self.ops_loss[self.reconstruction_loss_key] = (
                -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_diag_cov(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        reduce_sum=False)))
          # L1 norm.
          elif self.reconstruction_loss == 'l1':
            self.ops_loss[self.reconstruction_loss_key] = (
                self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.absolute_difference(
                        targets_mu, self.out_mu, reduction='none')))
          # Mean-squared error.
          elif self.reconstruction_loss == 'mse':
            self.ops_loss[self.reconstruction_loss_key] = (
                self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.mean_squared_error(
                        targets_mu, self.out_mu, reduction='none')))
          else:
            raise Exception('Undefined loss.')

      if 'loss_kld' not in self.ops_loss:
        with tf.name_scope('kld_loss'):
          self.ops_loss[
              'loss_kld'] = self.kld_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask * kld_normal_isotropic(
                      self.q_mu,
                      self.q_sigma,
                      self.p_mu,
                      self.p_sigma,
                      reduce_sum=False))

  def accumulate_loss(self):
    """Sums all registered losses into `self.loss` (the optimized objective)."""
    # Model.loss is used by the optimization function.
    self.loss = 0
    for loss_op in self.ops_loss.values():
      self.loss += loss_op
    self.ops_loss['total_loss'] = self.loss

  def log_loss(self, eval_loss, step=0, epoch=0, time_elapsed=None, prefix=''):
    """Prints step/epoch, the total loss and every other loss (sorted by key)."""
    loss_format = prefix + '{}/{} \t Total: {:.4f} \t'
    loss_entries = [step, epoch, eval_loss['total_loss']]

    for loss_key in sorted(eval_loss.keys()):
      if loss_key != 'total_loss':
        loss_format += '{}: {:.4f} \t'
        loss_entries.append(loss_key)
        loss_entries.append(eval_loss[loss_key])

    if time_elapsed is not None:
      print(
          loss_format.format(*loss_entries) +
          'time/batch = {:.3f}'.format(time_elapsed))
    else:
      print(loss_format.format(*loss_entries))

  def log_num_parameters(self):
    """Counts (and prints) the elements of all global variables."""
    num_param = 0
    for v in tf.global_variables():
      # int(): np.prod of a scalar variable's empty shape is the float 1.0.
      num_param += int(np.prod(v.get_shape().as_list()))

    self.num_parameters = num_param
    print('# of parameters: ' + str(num_param))

  def create_summary_plots(self):
    """
        Creates scalar summaries for loss plots by iterating over `ops_loss`.

        In `validation` mode a different strategy is used: to get a consistent
        validation report over iterations, model performance is first
        collected on every validation mini-batch and the average loss is then
        reported. Since tensorflow has no loss-averaging ops, one placeholder
        per loss is created to pass the average loss.
        """
    if self.is_training:
      for loss_name, loss_op in self.ops_loss.items():
        tf.summary.scalar(
            loss_name,
            loss_op,
            collections=[self.mode + '_summary_plot', self.mode + '_loss'])

    elif self.is_validation:  # Validation: first accumulate losses and then log them.
      self.container_loss = {}
      self.container_loss_placeholders = {}
      self.container_validation_feed_dict = {}
      self.validation_summary_num_runs = 0

      for loss_name in self.ops_loss:
        self.container_loss[loss_name] = 0
        self.container_loss_placeholders[loss_name] = tf.placeholder(
            tf.float32, shape=[])
        tf.summary.scalar(
            loss_name,
            self.container_loss_placeholders[loss_name],
            collections=[self.mode + '_summary_plot', self.mode + '_loss'])
        self.container_validation_feed_dict[
            self.container_loss_placeholders[loss_name]] = 0

    for summary_name, scalar_summary_op in self.ops_scalar_summary.items():
      tf.summary.scalar(
          summary_name,
          scalar_summary_op,
          collections=[
              self.mode + '_summary_plot', self.mode + '_scalar_summary'
          ])

    # Create summaries to visualize distribution of latent variables.
    if self.config['tensorboard_verbose'] > 0:
      histogram_collections = [
          self.mode + '_summary_plot', self.mode + '_stochastic_variables'
      ]
      for name, tensor in (('p_mu', self.p_mu), ('p_sigma', self.p_sigma),
                           ('q_mu', self.q_mu), ('q_sigma', self.q_sigma),
                           ('out_mu', self.out_mu),
                           ('out_sigma', self.out_sigma)):
        tf.summary.histogram(name, tensor, collections=histogram_collections)

    self.loss_summary = tf.summary.merge_all(self.mode + '_summary_plot')

  ########################################
  # Summary methods for validation mode.
  ########################################
  def update_validation_loss(self, loss_evaluated):
    """Adds one mini-batch's evaluated losses to the running sums."""
    self.validation_summary_num_runs += 1
    for loss_name, loss_value in loss_evaluated.items():
      self.container_loss[loss_name] += loss_value

  def reset_validation_loss(self):
    """Zeroes the accumulated validation losses."""
    for loss_name in self.container_loss:
      self.container_loss[loss_name] = 0

  def get_validation_summary(self, session):
    """Averages accumulated validation losses and evaluates their summary.

    The averages are written back into `self.container_loss` and the run
    counter is reset. With no accumulated runs the sums are left as-is
    instead of dividing by zero.

    Returns:
      (summary evaluated by `session.run`, dict of average losses).
    """
    num_runs = max(self.validation_summary_num_runs, 1)
    for loss_name, loss_pl in self.container_loss_placeholders.items():
      self.container_loss[loss_name] /= num_runs
      self.container_validation_feed_dict[loss_pl] = self.container_loss[
          loss_name]
    self.validation_summary_num_runs = 0

    valid_summary = session.run(self.loss_summary,
                                self.container_validation_feed_dict)
    return valid_summary, self.container_loss

  ########################################
  # Evaluation methods.
  ########################################

  def reconstruct_given_sample(self,
                               session,
                               inputs,
                               targets=None,
                               ops_eval=None):
    """
        Reconstructs a given sample.
        Args:
            session:
            inputs: input array of size (batch_size, sequence_length,
              input_size); a 2D (sequence_length, input_size) array is treated
              as a batch of one.
            targets: to calculate model loss. if None, then loss is not
              calculated.
            ops_eval: ops to be evaluated by the model.

        Returns:
            list with the evaluated `ops_eval` and, if targets are given, the
            evaluated `ops_loss`.
        """
    model_inputs = np.expand_dims(
        inputs, axis=0) if inputs.ndim == 2 else inputs
    model_targets = np.expand_dims(
        targets,
        axis=0) if (targets is not None) and (targets.ndim == 2) else targets
    eval_op_list = []
    if ops_eval is None:
      ops_eval = self.ops_evaluation
    eval_op_list.append(ops_eval)

    # One (full) sequence length per batch row.
    feed = {
        self.inputs: model_inputs,
        self.input_seq_length:
            np.ones(model_inputs.shape[0]) * model_inputs.shape[1]
    }

    if model_targets is not None:
      feed[self.targets] = model_targets
      eval_op_list.append(self.ops_loss)

    eval_results = session.run(eval_op_list, feed)

    return eval_results

  def sample_unbiased(self, session, seq_len=500, ops_eval=None, **kwargs):
    """
        Generates new samples randomly. Note that this function works only if
        the model is created in "sampling" mode.
        Args:
            **kwargs:
            session:
            seq_len: # of frames.
            ops_eval: ops to be evaluated by the model.

        Returns:
            list with the evaluated `ops_eval`.
        """
    dummy_x = np.zeros((self.batch_size, seq_len, sum(self.input_dims)))
    prev_state = session.run(
        self.cell.zero_state(batch_size=self.batch_size, dtype=tf.float32))

    eval_op_list = []
    if ops_eval is None:
      ops_eval = self.ops_evaluation
    eval_op_list.append(ops_eval)

    model_inputs = dummy_x
    feed = {
        self.inputs: model_inputs,
        self.input_seq_length:
            np.ones(model_inputs.shape[0]) * model_inputs.shape[1],
        self.initial_state: prev_state
    }

    eval_results = session.run(eval_op_list, feed)
    return eval_results

  def sample_biased(self,
                    session,
                    seq_len,
                    prev_state,
                    prev_sample=None,
                    ops_eval=None,
                    **kwargs):
    """
        Initializes the model by using state of a real sample.
        Args:
            session:
            seq_len: total length, including the prefix taken from
              `prev_sample`.
            prev_state: rnn state to be used as reference.
            prev_sample: sample that is used to bias the model and generate
              prev_state. If not None, then it is concatenated with the
              synthetic sample for visualization.
            ops_eval: ops to be evaluated by the model.

        Returns:
            list with the evaluated `ops_eval`.
        """

    ref_len = 0
    if prev_sample is not None:
      prev_sample = np.expand_dims(
          prev_sample, axis=0) if prev_sample.ndim == 2 else prev_sample
      ref_len = prev_sample.shape[1]

      output_sample_concatenated = np.zeros(
          (self.batch_size, seq_len, self.output_dim), dtype=np.float32)
      output_sample_concatenated[:, :ref_len] = prev_sample[:, :ref_len]

    seq_len = seq_len - ref_len
    dummy_x = np.zeros((self.batch_size, seq_len, sum(self.input_dims)))

    eval_op_list = []
    if ops_eval is None:
      ops_eval = self.ops_evaluation
    eval_op_list.append(ops_eval)

    model_inputs = dummy_x
    feed = {
        self.inputs: model_inputs,
        self.input_seq_length:
            np.ones(model_inputs.shape[0]) * model_inputs.shape[1],
        self.initial_state: prev_state
    }

    eval_results = session.run(eval_op_list, feed)

    if prev_sample is not None:
      output_sample_concatenated[:, ref_len:] = eval_results[0]['output_sample']
      eval_results[0]['output_sample'] = output_sample_concatenated

    return eval_results


class VRNNGMM(VRNN):
  """VRNN variant whose cell has a Gaussian-mixture latent space.

  Optionally uses a temporal latent space, a variational (learned prior) pi
  and auxiliary losses (pi KL, pi entropy, classification, GMM sigma
  regularizer), all driven by `config`.
  """

  def __init__(self,
               config,
               input_op,
               input_seq_length_op,
               target_op,
               input_dims,
               target_dims,
               reuse,
               batch_size=-1,
               mode='training'):
    VRNN.__init__(
        self,
        config,
        input_op,
        input_seq_length_op,
        target_op,
        input_dims,
        target_dims,
        reuse,
        batch_size=batch_size,
        mode=mode)

    # VRNNCellGMM configuration.
    self.use_temporal_latent_space = config.get('use_temporal_latent_space',
                                                True)
    self.use_variational_pi = config.get('use_variational_pi', False)
    self.use_real_pi_labels = config.get('use_real_pi_labels', False)
    self.use_soft_gmm = config.get('use_soft_gmm', False)
    self.is_gmm_active = not (config.get('use_pi_as_content', False))

    self.num_gmm_components = config['num_gmm_components']
    self.gmm_component_size = config['gmm_component_size']

    loss_weights = self.config.get('loss_weights', {})
    self.kld_loss_pi_weight = loss_weights.get('kld_loss_pi', 1)
    # A weight of None (missing key) disables the corresponding loss.
    self.gmm_sigma_regularizer_weight = loss_weights.get(
        'gmm_sigma_regularizer', None)
    self.classification_loss_weight = loss_weights.get('classification_loss',
                                                       None)
    self.pi_entropy_loss_weight = loss_weights.get('pi_entropy_loss', None)

    self.use_classification_loss = self.classification_loss_weight is not None
    self.use_gmm_sigma_loss = self.gmm_sigma_regularizer_weight is not None
    self.use_pi_entropy_loss = self.pi_entropy_loss_weight is not None

    # Sanity Check
    if target_op is not None or self.is_training or self.is_validation:
      assert not (self.use_real_pi_labels and len(self.target_dims) < 2
                 ), 'Real labels are not provided: rank(target_dims) < 2.'
      assert not (
          self.use_classification_loss and len(self.target_dims) < 2
      ), ('Real labels are not provided for classification loss: '
          'rank(target_dims) < 2.')

  def build_cell(self):
    """Creates the GMM cell, fetches its GMM components and the zero state."""
    # The cell is constructed the same way in every mode (`mode` is validated
    # in VRNN.__init__); the mode itself is forwarded to the cell.
    self.cell = self.vrnn_cell_constructor(
        reuse=self.reuse, mode=self.mode, config=self.vrnn_cell_args)

    assert isinstance(
        self.cell, VRNNGmmCell
    ), 'Cell object must be an instance of VRNNCellGMM for VRNNGMM model.'

    # GMM components are 2D: [# components, component size]
    if self.is_gmm_active:
      self.gmm_mu, self.gmm_sigma = self.cell.get_gmm_components()

    self.initial_state = self.cell.zero_state(
        batch_size=self.batch_size, dtype=tf.float32)

  def build_predictions_layer(self):
    """Unpacks the RNN outputs and registers evaluation/summary ops.

    The unpacking order mirrors the cell output: q/p statistics (temporal
    latent space only), gmm_z, q_pi logits, p_pi logits (variational pi only),
    out_mu, out_sigma.
    """
    # Assign rnn outputs.
    if self.use_temporal_latent_space and self.use_variational_pi:
      (self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.gmm_z, self.q_pi,
       self.p_pi, self.out_mu, self.out_sigma) = self.outputs
    elif self.use_temporal_latent_space:
      (self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.gmm_z, self.q_pi,
       self.out_mu, self.out_sigma) = self.outputs
    elif self.use_variational_pi:
      (self.gmm_z, self.q_pi, self.p_pi, self.out_mu,
       self.out_sigma) = self.outputs
    else:
      self.gmm_z, self.q_pi, self.out_mu, self.out_sigma = self.outputs

    # TODO: Sampling option.
    self.output_sample = self.out_mu
    self.input_sample = self.inputs
    self.output_dim = self.output_sample.shape.as_list()[-1]

    self.ops_evaluation['output_sample'] = self.output_sample
    if self.use_temporal_latent_space:
      self.ops_evaluation['p_mu'] = self.p_mu
      self.ops_evaluation['p_sigma'] = self.p_sigma
      self.ops_evaluation['q_mu'] = self.q_mu
      self.ops_evaluation['q_sigma'] = self.q_sigma
    if self.use_variational_pi:
      self.ops_evaluation['p_pi'] = tf.nn.softmax(self.p_pi, axis=-1)
    self.ops_evaluation['q_pi'] = tf.nn.softmax(self.q_pi, axis=-1)
    self.ops_evaluation['gmm_z'] = self.gmm_z
    self.ops_evaluation['state'] = self.output_state

    # NOTE(review): normalizer is (static batch size) * (sum of all sequence
    # lengths); kept as-is to not change reported values.
    num_entries = tf.cast(
        self.input_seq_length.shape.as_list()[0] *
        tf.reduce_sum(self.input_seq_length), tf.float32)
    self.ops_scalar_summary['mean_out_sigma'] = tf.reduce_sum(
        self.out_sigma) / num_entries
    # p/q statistics only exist with a temporal latent space.
    if self.use_temporal_latent_space:
      self.ops_scalar_summary['mean_p_sigma'] = tf.reduce_sum(
          self.p_sigma) / num_entries
      self.ops_scalar_summary['mean_q_sigma'] = tf.reduce_sum(
          self.q_sigma) / num_entries

    # Mask for precise loss calculation: 1.0 for valid steps, 0.0 for padding.
    self.seq_loss_mask = tf.expand_dims(
        tf.sequence_mask(
            lengths=self.input_seq_length,
            maxlen=tf.reduce_max(self.input_seq_length),
            dtype=tf.float32), -1)

  def build_loss(self):
    """Adds reconstruction, KL and optional auxiliary losses to `ops_loss`.

    Losses are only built in training/validation mode; a loss already
    registered under the same key (e.g. by a subclass) is not rebuilt.

    Raises:
      Exception: if `self.reconstruction_loss` is not one of 'nll_normal',
        'nll_normal_iso', 'l1' or 'mse'.
    """
    if self.is_training or self.is_validation:
      # TODO: Use dataset object to parse the concatenated targets.
      targets_mu = self.target_pieces[0]

      if self.reconstruction_loss_key not in self.ops_loss:
        with tf.name_scope('reconstruction_loss'):
          # Gaussian log likelihood loss. 'nll_normal' is the default value of
          # the 'reconstruction_loss' config entry, so it is handled here
          # together with 'nll_normal_iso'.
          if self.reconstruction_loss in ('nll_normal', 'nll_normal_iso'):
            self.ops_loss[self.reconstruction_loss_key] = (
                -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_diag_cov(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        reduce_sum=False)))
          # L1 norm.
          elif self.reconstruction_loss == 'l1':
            self.ops_loss[self.reconstruction_loss_key] = (
                self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.absolute_difference(
                        targets_mu, self.out_mu, reduction='none')))
          # Mean-squared error.
          elif self.reconstruction_loss == 'mse':
            self.ops_loss[self.reconstruction_loss_key] = (
                self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.mean_squared_error(
                        targets_mu, self.out_mu, reduction='none')))
          else:
            raise Exception('Undefined loss.')

      if self.use_temporal_latent_space and 'loss_kld' not in self.ops_loss:
        with tf.name_scope('kld_loss'):
          self.ops_loss[
              'loss_kld'] = self.kld_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask * kld_normal_isotropic(
                      self.q_mu,
                      self.q_sigma,
                      self.p_mu,
                      self.p_sigma,
                      reduce_sum=False))

      # Categorical distribution over GMM components, one per (sample, step).
      flat_q_pi = tf.reshape(self.q_pi, [-1, self.num_gmm_components])
      self.dist_q = tf.compat.v1.distributions.Categorical(logits=flat_q_pi)

      if self.use_variational_pi and 'loss_kld_pi' not in self.ops_loss:
        with tf.name_scope('kld_pi_loss'):
          flat_p_pi = tf.reshape(self.p_pi, [-1, self.num_gmm_components])
          self.dist_p = tf.compat.v1.distributions.Categorical(logits=flat_p_pi)

          flat_kld_cat_loss = tf.compat.v1.distributions.kl_divergence(
              distribution_a=self.dist_q, distribution_b=self.dist_p)
          temporal_kld_cat_loss = tf.reshape(flat_kld_cat_loss,
                                             [self.batch_size, -1, 1])
          self.ops_loss[
              'loss_kld_pi'] = self.kld_loss_pi_weight * self.reduce_loss_func(
                  self.seq_loss_mask * temporal_kld_cat_loss)

      if self.use_pi_entropy_loss and 'loss_entropy_pi' not in self.ops_loss:
        # Scaled by its own configured weight ('pi_entropy_loss').
        self.ops_loss[
            'loss_entropy_pi'] = self.pi_entropy_loss_weight * self.reduce_loss_func(
                self.seq_loss_mask *
                tf.reshape(self.dist_q.entropy(), [self.batch_size, -1, 1]))

      if self.use_classification_loss and 'loss_classification' not in self.ops_loss:
        targets_categorical_labels = self.target_pieces[1]
        # Use GMM latent space probabilities as class predictions.
        self.label_predictions = self.q_pi

        with tf.name_scope('classification_loss'):
          prediction_size = targets_categorical_labels.get_shape().as_list()[-1]
          flat_labels = tf.reshape(targets_categorical_labels,
                                   [-1, prediction_size])
          flat_predictions = tf.reshape(self.label_predictions,
                                        [-1, prediction_size])

          flat_char_classification_loss = tf.losses.softmax_cross_entropy(
              flat_labels, flat_predictions, reduction='none')
          temporal_char_classification_loss = tf.reshape(
              flat_char_classification_loss, [self.batch_size, -1, 1])
          self.ops_loss[
              'loss_classification'] = self.classification_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask * temporal_char_classification_loss)

      if self.is_gmm_active and self.use_gmm_sigma_loss and 'loss_gmm_sigma' not in self.ops_loss:
        with tf.name_scope('gmm_sigma_loss'):
          # Pulls the GMM sigmas towards 1, scaled by the configured
          # 'gmm_sigma_regularizer' weight like every other loss.
          self.ops_loss['loss_gmm_sigma'] = (
              self.gmm_sigma_regularizer_weight *
              tf.reduce_mean(tf.square(1 - self.gmm_sigma)))

  def create_summary_plots(self):
    """
        Creates scalar summaries for loss plots by iterating over `ops_loss`.

        In `validation` mode a different strategy is used: to get a consistent
        validation report over iterations, model performance is first
        collected on every validation mini-batch and the average loss is then
        reported. Since tensorflow has no loss-averaging ops, one placeholder
        per loss is created to pass the average loss.
        """
    if self.is_training:
      for loss_name, loss_op in self.ops_loss.items():
        tf.summary.scalar(
            loss_name,
            loss_op,
            collections=[self.mode + '_summary_plot', self.mode + '_loss'])

    elif self.is_validation:  # Validation: first accumulate losses and then log them.
      self.container_loss = {}
      self.container_loss_placeholders = {}
      self.container_validation_feed_dict = {}
      self.validation_summary_num_runs = 0

      for loss_name in self.ops_loss:
        self.container_loss[loss_name] = 0
        self.container_loss_placeholders[loss_name] = tf.placeholder(
            tf.float32, shape=[])
        tf.summary.scalar(
            loss_name,
            self.container_loss_placeholders[loss_name],
            collections=[self.mode + '_summary_plot', self.mode + '_loss'])
        self.container_validation_feed_dict[
            self.container_loss_placeholders[loss_name]] = 0

    for summary_name, scalar_summary_op in self.ops_scalar_summary.items():
      tf.summary.scalar(
          summary_name,
          scalar_summary_op,
          collections=[
              self.mode + '_summary_plot', self.mode + '_scalar_summary'
          ])

    # Create summaries to visualize distribution of latent variables.
    if self.config['tensorboard_verbose'] > 0:
      histogram_collections = [
          self.mode + '_summary_plot', self.mode + '_stochastic_variables'
      ]
      histograms = []
      if self.is_gmm_active:
        histograms += [('gmm_mu', self.gmm_mu), ('gmm_sigma', self.gmm_sigma)]
      if self.use_temporal_latent_space:
        histograms += [('p_mu', self.p_mu), ('p_sigma', self.p_sigma),
                       ('q_mu', self.q_mu), ('q_sigma', self.q_sigma)]
      if self.use_variational_pi:
        histograms.append(('p_pi', tf.nn.softmax(self.p_pi)))
      histograms += [('q_pi', tf.nn.softmax(self.q_pi)),
                     ('out_mu', self.out_mu), ('out_sigma', self.out_sigma)]

      for name, tensor in histograms:
        tf.summary.histogram(name, tensor, collections=histogram_collections)

    self.loss_summary = tf.summary.merge_all(self.mode + '_summary_plot')

  def evaluate_gmm_latent_space(self, session):
    """Returns the evaluated (gmm_mu, gmm_sigma) arrays."""
    gmm_mus, gmm_sigmas = session.run([self.gmm_mu, self.gmm_sigma])
    return gmm_mus, gmm_sigmas


import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
import PIL
"""Functions to save plots, matrices as image.
"""


def plot_and_get_image(plot_data, fig_height=8, fig_width=12, axis_off=False):
  """Draws `plot_data` as a line plot and returns the rendered image.

  Args:
      plot_data: data accepted by `plt.plot`.
      fig_height: figure height (inches).
      fig_width: figure width (inches).
      axis_off: if True, hide the axes.

  Returns:
      The figure rendered by `fig_to_img`. The figure is closed afterwards.
  """
  figure = plt.figure()
  figure.set_size_inches(fig_width, fig_height)
  plt.plot(plot_data)
  if axis_off:
    plt.axis('off')

  rendered = fig_to_img(figure)
  plt.close(figure)
  return rendered


def plot_matrix_and_get_image(plot_data,
                              fig_height=8,
                              fig_width=12,
                              axis_off=False,
                              colormap='jet'):
  """Draws a 2-D matrix with a colorbar and returns the rendered image.

  Args:
      plot_data: two dimensional array visualized with `plt.matshow`.
      fig_height: figure height (inches).
      fig_width: figure width (inches).
      axis_off: if True, hide the axes.
      colormap: name of the matplotlib colormap used for this matrix only.

  Returns:
      The figure rendered by `fig_to_img`. The figure is closed afterwards.
  """
  fig = plt.figure()
  fig.set_figheight(fig_height)
  fig.set_figwidth(fig_width)
  # Apply the colormap to this image only. `plt.set_cmap` would also change
  # the global default colormap (rcParams['image.cmap']) for every later plot.
  plt.matshow(plot_data, fig.number, cmap=colormap)

  # Horizontal colorbar for wide figures, vertical otherwise.
  if fig_height < fig_width:
    plt.colorbar(orientation='horizontal')
  else:
    plt.colorbar(orientation='vertical')

  if axis_off:
    plt.axis('off')

  img = fig_to_img(fig)
  plt.close(fig)
  return img


def plot_matrices(plot_data,
                  title_data=None,
                  row_colorbar=True,
                  fig_height=8,
                  fig_width=12,
                  show_plot=False):
  """Plots one or two rows of matrix pairs and returns the rendered image.

  Args:
      plot_data: A dictionary with positional index keys such as "00", "01",
        "10" and "11" where each entry is a two dimensional matrix. "00" and
        "01" are required. If "10" is present, "11" must be present too and
        a second row is drawn.
      title_data: An optional dictionary with the same positional keys where
        each entry is a plot title. Missing keys get an empty title.
      row_colorbar: Use a common colorbar for two matrices in the same row. If
        False, each matrix has its own colorbar.
      fig_height: figure height (inches).
      fig_width: figure width (inches).
      show_plot: if True, call `plt.show()` instead of closing the figure.

  Returns:
      The figure rendered by `fig_to_img`.
  """
  if title_data is None:
    title_data = {}

  num_latent_neurons, seq_len = plot_data['00'].shape
  aspect_ratio = max(round((seq_len / num_latent_neurons) / 2), 1)
  yticks = np.int32(
      np.linspace(
          start=0, stop=num_latent_neurons - 1, num=min(16,
                                                        num_latent_neurons)))

  # The second row is only drawn when the "10" entry is provided.
  nrows = 2 if '10' in plot_data else 1

  if row_colorbar:
    # Columns: [left matrix, right matrix, shared colorbar].
    width_ratios = [1, 1, 0.05]
  else:
    # Columns: [left matrix, its colorbar, right matrix, its colorbar].
    width_ratios = [1, 0.05, 1, 0.05]
  fig, axes = plt.subplots(
      nrows=nrows,
      ncols=len(width_ratios),
      figsize=(fig_width, fig_height),
      gridspec_kw={'width_ratios': width_ratios})

  if nrows == 1:
    # Keep (row, column) indexing valid for a single row.
    axes = np.expand_dims(axes, axis=0)

  plt.subplots_adjust(
      left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
  fig.tight_layout()
  plt.setp(axes, yticks=yticks)

  def draw_row(row, left_key, right_key):
    """Draws the matrices `left_key` and `right_key` into row `row`."""
    if row_colorbar:
      left, right = plot_data[left_key], plot_data[right_key]
      min_value = min(left.min(), right.min())
      max_value = max(left.max(), right.max())

      axes[row, 0].imshow(
          left, vmin=min_value, vmax=max_value, aspect=aspect_ratio)
      axes[row, 0].set_title(title_data.get(left_key, ''))

      colorbar_ref = axes[row, 1].imshow(
          right, vmin=min_value, vmax=max_value, aspect=aspect_ratio)
      axes[row, 1].set_title(title_data.get(right_key, ''))

      # Place the shared colorbar axis just right of the right-hand matrix.
      ip = InsetPosition(axes[row, 1], [1.05, 0, 0.05, 1])
      axes[row, 2].set_axes_locator(ip)
      fig.colorbar(
          colorbar_ref, cax=axes[row, 2], ax=[axes[row, 0], axes[row, 1]])
    else:
      # Each matrix is followed by its own colorbar column.
      for image_col, key in ((0, left_key), (2, right_key)):
        matrix = plot_data[key]
        colorbar_ref = axes[row, image_col].imshow(
            matrix, vmin=matrix.min(), vmax=matrix.max(), aspect=aspect_ratio)
        axes[row, image_col].set_title(title_data.get(key, ''))
        ip = InsetPosition(axes[row, image_col], [1.05, 0, 0.05, 1])
        axes[row, image_col + 1].set_axes_locator(ip)
        fig.colorbar(
            colorbar_ref,
            cax=axes[row, image_col + 1],
            ax=axes[row, image_col])

  draw_row(0, '00', '01')
  if nrows == 2:
    draw_row(1, '10', '11')

  img = fig_to_img(fig)
  if show_plot:
    plt.show()
  else:
    plt.close(fig)
  return img


def plot_latent_variables(plot_data,
                          fig_height=8,
                          fig_width=12,
                          show_plot=False):
  """Plots the q_* and p_* latent statistics side by side.

  Row 0 shows "q_mu" next to "p_mu", row 1 shows "q_sigma" next to
  "p_sigma". Both matrices in a row share one color scale and colorbar.

  Args:
      plot_data: a dictionary with keys "q_mu", "q_sigma", "p_mu" and
        "p_sigma" where each field is a two dimensional matrix with size of
        (num_latent_neurons, seq_len).
      fig_height: figure height (inches).
      fig_width: figure width (inches).
      show_plot: if True, call `plt.show()` instead of closing the figure.

  Returns:
      The figure rendered by `fig_to_img`.
  """
  num_latent_neurons, seq_len = plot_data['p_mu'].shape
  aspect_ratio = max(round((seq_len / num_latent_neurons) / 2), 1)
  tick_count = min(16, num_latent_neurons)
  yticks = np.int32(
      np.linspace(start=0, stop=num_latent_neurons - 1, num=tick_count))

  fig, axes = plt.subplots(
      nrows=2,
      ncols=3,
      figsize=(fig_width, fig_height),
      gridspec_kw={'width_ratios': [1, 1, 0.05]})
  plt.subplots_adjust(
      left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
  fig.tight_layout()
  plt.setp(axes, yticks=yticks)

  panel_rows = (('q_mu', 'p_mu'), ('q_sigma', 'p_sigma'))
  for row, (q_key, p_key) in enumerate(panel_rows):
    q_values = plot_data[q_key]
    p_values = plot_data[p_key]
    low = min(q_values.min(), p_values.min())
    high = max(q_values.max(), p_values.max())

    axes[row, 0].imshow(q_values, vmin=low, vmax=high, aspect=aspect_ratio)
    axes[row, 0].set_title(q_key)

    p_image = axes[row, 1].imshow(
        p_values, vmin=low, vmax=high, aspect=aspect_ratio)
    axes[row, 1].set_title(p_key)

    # Shared colorbar placed just right of the p_* panel.
    axes[row, 2].set_axes_locator(
        InsetPosition(axes[row, 1], [1.05, 0, 0.05, 1]))
    fig.colorbar(p_image, cax=axes[row, 2], ax=[axes[row, 0], axes[row, 1]])

  img = fig_to_img(fig)
  if not show_plot:
    plt.close(fig)
  else:
    plt.show()
  return img


def plot_latent_categorical_variables(plot_data,
                                      fig_height=8,
                                      fig_width=12,
                                      show_plot=False):
  """Plots "q_pi" next to "p_pi" with a shared color scale and colorbar.

  Args:
      plot_data: a dictionary with keys "q_pi" and "p_pi" where each field is
        a two dimensional matrix with size of (num_latent_neurons, seq_len).
      fig_height: figure height (inches).
      fig_width: figure width (inches).
      show_plot: if True, call `plt.show()` instead of closing the figure.

  Returns:
      The figure rendered by `fig_to_img`.
  """
  q_pi = plot_data['q_pi']
  p_pi = plot_data['p_pi']

  num_latent_neurons, seq_len = q_pi.shape
  aspect_ratio = max(round((seq_len / num_latent_neurons) / 2), 1)
  tick_count = min(16, num_latent_neurons)
  yticks = np.int32(
      np.linspace(start=0, stop=num_latent_neurons - 1, num=tick_count))

  fig, axes = plt.subplots(
      nrows=1,
      ncols=3,
      figsize=(fig_width, fig_height),
      gridspec_kw={'width_ratios': [1, 1, 0.05]})
  plt.subplots_adjust(
      left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
  fig.tight_layout()
  plt.setp(axes, yticks=yticks)

  lowest = min(q_pi.min(), p_pi.min())
  highest = max(q_pi.max(), p_pi.max())
  q_axis, p_axis, colorbar_axis = axes

  q_axis.imshow(q_pi, vmin=lowest, vmax=highest, aspect=aspect_ratio)
  q_axis.set_title('q_pi')

  p_image = p_axis.imshow(p_pi, vmin=lowest, vmax=highest, aspect=aspect_ratio)
  p_axis.set_title('p_pi')

  # Shared colorbar placed just right of the p_pi panel.
  colorbar_axis.set_axes_locator(InsetPosition(p_axis, [1.05, 0, 0.05, 1]))
  fig.colorbar(p_image, cax=colorbar_axis, ax=[q_axis, p_axis])

  img = fig_to_img(fig)
  if not show_plot:
    plt.close(fig)
  else:
    plt.show()
  return img


def fig_to_img(fig):
  """Converts a Matplotlib figure to an RGB image in numpy array format.

  Args:
      fig: a matplotlib figure whose canvas provides `buffer_rgba()` (e.g. the
        default Agg canvas).

  Returns:
      A `(height, width, 3)` uint8 numpy array with the RGB pixels.
  """
  fig.canvas.draw()

  # `buffer_rgba()` exposes the rendered pixels as a (height, width, 4) array.
  # It replaces `tostring_rgb()` (removed in Matplotlib 3.10) and the
  # deprecated binary mode of `np.fromstring`. Its shape is already in
  # physical pixels, so no reshape from `get_width_height()` is needed (which
  # breaks on HiDPI canvases).
  rgba = np.asarray(fig.canvas.buffer_rgba())
  # Drop alpha. `np.array` copies, so the result stays valid and writable
  # after the canvas is redrawn or the figure is closed.
  return np.array(rgba[..., :3], dtype=np.uint8)


def fig_to_img_pil(fig):
  """Converts a Matplotlib figure to an RGBA image in numpy array format.

  The result has the same layout as the previous ARGB -> RGBA -> PIL
  round trip: a `(height, width, 4)` uint8 array in RGBA channel order.

  Args:
      fig: a matplotlib figure whose canvas provides `buffer_rgba()`.

  Returns:
      A `(height, width, 4)` uint8 numpy array with RGBA pixels.
  """
  # Draw the renderer so the buffer holds the current figure contents.
  fig.canvas.draw()

  # The canvas already exposes RGBA pixels shaped (height, width, 4). This
  # avoids the deprecated `np.fromstring`/`ndarray.tostring`, the manual
  # alpha-channel roll, and a dependency on `PIL.Image` being importable as
  # an attribute of the bare `PIL` package. `np.array` returns an
  # independent copy.
  return np.array(np.asarray(fig.canvas.buffer_rgba()), dtype=np.uint8)


class HandwritingVRNNModel(VRNN):
  """VRNN model for handwriting strokes.

  Extends `VRNN` with a handwriting-specific prediction layer (stroke
  distribution parameters plus a pen output), stroke reconstruction and pen
  losses, and optional image summaries.
  """

  def __init__(self,
               config,
               input_op,
               input_seq_length_op,
               target_op,
               input_dims,
               target_dims,
               reuse,
               data_processor,
               batch_size=-1,
               mode='training'):
    """Initializes the model.

    Args:
      config: model configuration dictionary. It is also used (without a
        copy) as the cell-argument dictionary; see below.
      input_op, input_seq_length_op, target_op, input_dims, target_dims,
        reuse, batch_size, mode: forwarded unchanged to `VRNN.__init__`.
      data_processor: dataset object, stored as `self.dataset_obj`.
    """
    VRNN.__init__(
        self,
        config,
        input_op,
        input_seq_length_op,
        target_op,
        input_dims,
        target_dims,
        reuse,
        batch_size=batch_size,
        mode=mode)

    self.dataset_obj = data_processor
    # Pen loss weight: config['loss_weights']['pen_loss'], defaulting to 1.
    self.pen_loss_weight = self.config.get('loss_weights',
                                           {}).get('pen_loss', 1)

    # TODO: Create a dictionary just for cell arguments.
    # NOTE: `config` is aliased, not copied, so the next line also writes
    # 'input_dims' into the caller's config dictionary.
    self.vrnn_cell_args = config
    self.vrnn_cell_args['input_dims'] = self.input_dims

    # See `create_image_summary` method for details.
    self.img_summary_entries = []
    self.ops_img_summary = {}
    self.use_img_summary = False

  def get_constructors(self):
    """Sets the RNN cell class (`HandWritingVRNNGmmCell`).

    Presumably consumed by the parent `VRNN` when building the graph.
    """
    self.vrnn_cell_constructor = HandWritingVRNNGmmCell

  def build_predictions_layer(self):
    """Unpacks the cell outputs and registers evaluation ops.

    Assumes `self.outputs` holds exactly these 8 entries in this order:
    q_mu, q_sigma, p_mu, p_sigma, out_mu, out_sigma, out_rho, out_pen.
    """
    # Assign rnn outputs.
    self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.out_mu, self.out_sigma, self.out_rho, self.out_pen = self.outputs

    # For analysis.
    # L2 norms over the last axis of each tensor.
    self.norm_p_mu = tf.norm(self.p_mu, axis=-1)
    self.norm_p_sigma = tf.norm(self.p_sigma, axis=-1)
    self.norm_q_mu = tf.norm(self.q_mu, axis=-1)
    self.norm_q_sigma = tf.norm(self.q_sigma, axis=-1)
    self.norm_out_mu = tf.norm(self.out_mu, axis=-1)
    self.norm_out_sigma = tf.norm(self.out_sigma, axis=-1)

    # TODO: Sampling option.
    # The output sample is the predicted mean stroke followed by the rounded
    # pen output, concatenated along axis 2.
    self.output_sample = tf.concat(
        [self.out_mu, tf.round(self.out_pen)], axis=2)
    self.input_sample = self.inputs
    self.output_dim = self.output_sample.shape.as_list()[-1]

    self.ops_evaluation['output_sample'] = self.output_sample
    self.ops_evaluation['p_mu'] = self.p_mu
    self.ops_evaluation['p_sigma'] = self.p_sigma
    self.ops_evaluation['q_mu'] = self.q_mu
    self.ops_evaluation['q_sigma'] = self.q_sigma
    self.ops_evaluation['state'] = self.output_state

    # In case we want to draw samples from output distribution instead of using mean.
    self.ops_evaluation['out_mu'] = self.out_mu
    self.ops_evaluation['out_sigma'] = self.out_sigma
    self.ops_evaluation['out_rho'] = self.out_rho
    self.ops_evaluation['out_pen'] = self.out_pen

    # Mask for precise loss calculation.
    # float32 mask of shape (batch, max_len, 1): 1 for steps within each
    # sequence's length and 0 for padding. The trailing axis broadcasts
    # against per-step losses.
    self.seq_loss_mask = tf.expand_dims(
        tf.sequence_mask(
            lengths=self.input_seq_length,
            maxlen=tf.reduce_max(self.input_seq_length),
            dtype=tf.float32), -1)

  def build_loss(self):
    """Adds stroke reconstruction and pen losses, then calls `VRNN.build_loss`.

    Only active in training/validation mode. Every loss is multiplied by
    `self.seq_loss_mask` so padded steps contribute zero, and losses already
    present in `self.ops_loss` are not rebuilt.

    Raises:
      Exception: if `self.reconstruction_loss` is not one of 'nll_normal_bi',
        'nll_normal_diag', 'l1' or 'mse'.
    """
    if self.is_training or self.is_validation:
      # target_pieces[0]: stroke targets, target_pieces[1]: pen targets.
      targets_mu = self.target_pieces[0]
      targets_pen = self.target_pieces[1]

      if self.reconstruction_loss_key not in self.ops_loss:
        with tf.name_scope('reconstruction_loss'):
          # Gaussian log likelihood loss (bivariate)
          if self.reconstruction_loss == 'nll_normal_bi':
            self.ops_loss[
                self.
                reconstruction_loss_key] = -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_bivariate(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        self.out_rho,
                        reduce_sum=False))
          # Gaussian log likelihood loss (diagonal covariance)
          elif self.reconstruction_loss == 'nll_normal_diag':
            self.ops_loss[
                self.
                reconstruction_loss_key] = -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_diag_cov(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        reduce_sum=False))
          # L1 norm.
          elif self.reconstruction_loss == 'l1':
            self.ops_loss[
                self.
                reconstruction_loss_key] = self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.absolute_difference(
                        targets_mu, self.out_mu, reduction='none'))
          # Mean-squared error.
          elif self.reconstruction_loss == 'mse':
            self.ops_loss[
                self.
                reconstruction_loss_key] = self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.mean_squared_error(
                        targets_mu, self.out_mu, reduction='none'))
          else:
            raise Exception('Undefined loss.')

      if 'loss_pen' not in self.ops_loss:
        with tf.name_scope('pen_reconstruction_loss'):
          # Bernoulli loss for pen information.
          self.ops_loss[
              'loss_pen'] = -self.pen_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask *
                  logli_bernoulli(targets_pen, self.out_pen, reduce_sum=False))

      # The parent adds its own loss terms (not visible here).
      VRNN.build_loss(self)

  def create_image_summary(self,
                           undo_preprocessing_func,
                           img_stroke_shape=(1, 120, 1200, 1),
                           img_norm_shape=(1, 800, 1200, 3)):
    """Creates placeholders and summary ops for image summaries.

    Does nothing unless `self.use_img_summary` is True. Supports two types of
    summaries:
    (1) stroke images (`stroke_img_entry`).
    (2) image visualizations of plots for a given sample
        (`norm_plot_img_entry`). These are needed to visualize model
        performance on a test sample over training.
    To add a new type, create an `img_entry` dict like the two below (fields
    `img_shape`, `num_img`, `data_type`, `post_processing_func`, `ops`) and
    append it to `self.img_summary_entries`.
    When `get_image_summary` is called, every registered op is evaluated.
    The results are converted into images by `post_processing_func` and
    stored in containers (`self.container_img`). The summary is then
    computed by feeding these containers to the placeholders in
    `self.container_img_placeholders`.

    Args:
        undo_preprocessing_func: function to undo normalization and
          preprocessing operations on model outputs. NOTE(review): not used
          by this method.
        img_stroke_shape: placeholder shape for stroke images, passed to
          `tf.summary.image`. The first entry is used as `num_img`.
        img_norm_shape: placeholder shape for norm plot images. The first
          entry is used as `num_img`.
    """

    # Post-processing functions for images.
    def norm_plot_img_func(img_data):
      return plot_and_get_image(img_data, axis_off=False)

    # NOTE(review): returns None, so `get_image_summary` stores
    # np.float32(None) (NaN) for stroke images. Stroke rendering appears
    # unimplemented.
    def stroke_img_func(img_data):
      pass

    if self.use_img_summary:
      # Make a separation between different types of images and provide corresponding functionality.
      stroke_img_entry = {}
      stroke_img_entry['img_shape'] = img_stroke_shape
      stroke_img_entry['num_img'] = img_stroke_shape[0]
      stroke_img_entry['data_type'] = tf.uint8
      stroke_img_entry['post_processing_func'] = stroke_img_func
      stroke_img_entry['ops'] = {}
      stroke_img_entry['ops']['stroke_output'] = self.output_sample
      if self.is_sampling is False:
        stroke_img_entry['ops']['stroke_input'] = self.input_sample

      norm_plot_img_entry = {}
      norm_plot_img_entry['img_shape'] = img_norm_shape
      norm_plot_img_entry['num_img'] = img_norm_shape[0]
      norm_plot_img_entry['data_type'] = tf.uint8
      norm_plot_img_entry['post_processing_func'] = norm_plot_img_func
      norm_plot_img_entry['ops'] = {}
      norm_plot_img_entry['ops']['norm_q_mu'] = self.norm_q_mu
      norm_plot_img_entry['ops']['norm_p_mu'] = self.norm_p_mu

      # NOTE(review): calling this method twice appends the entries again.
      self.img_summary_entries.append(stroke_img_entry)
      self.img_summary_entries.append(norm_plot_img_entry)
      # Graph nodes to be evaluated by calling session.run
      self.ops_img_summary = {}
      # Create placeholders and containers for intermediate results.
      self.container_img = {}
      self.container_img_placeholders = {}
      self.container_img_feed_dict = {}

      for summary_dict in self.img_summary_entries:
        for op_name, summary_op in summary_dict['ops'].items():
          self.ops_img_summary[op_name] = summary_op
          # To store images.
          self.container_img[op_name] = np.zeros(summary_dict['img_shape'])
          # To pass images to summary
          self.container_img_placeholders[op_name] = tf.placeholder(
              summary_dict['data_type'], summary_dict['img_shape'])
          # Summary.
          tf.summary.image(
              op_name,
              self.container_img_placeholders[op_name],
              collections=[self.mode + '_summary_img'],
              max_outputs=summary_dict['num_img'])
          # Feed dictionary.
          # Placeholder value; replaced with real images in get_image_summary.
          self.container_img_feed_dict[
              self.container_img_placeholders[op_name]] = 0

      self.img_summary = tf.summary.merge_all(self.mode + '_summary_img')

  def get_image_summary(self,
                        session,
                        ops_img_summary_evaluated=None,
                        seq_len=500):
    """Evaluates the model, creates output images and plots, and builds a summary.

    Args:
        session: tensorflow session.
        ops_img_summary_evaluated: dict mapping the op names registered in
          `create_image_summary` to their evaluated values. If None, a sample
          is drawn with `self.sample_unbiased` (sampling mode).
        seq_len: length of a synthetic sample.

    Returns:
        The serialized image summary for the summary writer, or None when
        `self.use_img_summary` is False.
    """
    if self.use_img_summary:
      if ops_img_summary_evaluated is None:  # Inference mode.
        ops_img_summary_evaluated = self.sample_unbiased(
            session, seq_len=seq_len, ops_eval=self.ops_img_summary)[0]

      # Create images.
      for summary_dict in self.img_summary_entries:
        post_processing_func = summary_dict['post_processing_func']
        for op_name, summary_op in summary_dict['ops'].items():
          # Convert the first `num_img` batch entries into images.
          for i in range(summary_dict['num_img']):
            self.container_img[op_name][i] = np.float32(
                post_processing_func(ops_img_summary_evaluated[op_name][i]))
          self.container_img_feed_dict[self.container_img_placeholders[
              op_name]] = self.container_img[op_name]

      img_summary = session.run(self.img_summary, self.container_img_feed_dict)

      return img_summary
    else:
      return None


class HandwritingVRNNGmmModel(VRNNGMM, HandwritingVRNNModel):
  """Text-conditioned handwriting model built on `VRNNGMM`.

  Combines `VRNNGMM` (GMM latent space with categorical `q_pi`/`p_pi`
  outputs) with the handwriting-specific parts of `HandwritingVRNNModel`. It
  adds an end-of-character (eoc) output and loss. `sample_biased` and
  `sample_unbiased` generate strokes for a given text one step at a time.
  """

  def __init__(self,
               config,
               input_op,
               input_seq_length_op,
               target_op,
               input_dims,
               target_dims,
               reuse,
               data_processor,
               batch_size=-1,
               mode='training'):
    """Initializes the model.

    Args:
      config: model configuration dictionary. It is also used (without a
        copy) as the cell-argument dictionary.
      input_op, input_seq_length_op, target_op, input_dims, target_dims,
        reuse, batch_size, mode: forwarded unchanged to `VRNNGMM.__init__`.
      data_processor: dataset object. Its `text_to_one_hot` method is used to
        convert text into character labels during sampling.
    """
    VRNNGMM.__init__(
        self,
        config,
        input_op,
        input_seq_length_op,
        target_op,
        input_dims,
        target_dims,
        reuse,
        batch_size=batch_size,
        mode=mode)

    self.dataset_obj = data_processor
    self.text_to_label_fn = data_processor.text_to_one_hot

    # Loss weights from config['loss_weights']. The bow loss is disabled
    # when no weight is configured.
    self.pen_loss_weight = self.config.get('loss_weights',
                                           {}).get('pen_loss', 1)
    self.eoc_loss_weight = self.config.get('loss_weights',
                                           {}).get('eoc_loss', 1)
    self.bow_loss_weight = self.config.get('loss_weights',
                                           {}).get('bow_loss', None)

    self.use_bow_loss = self.bow_loss_weight is not None
    self.use_bow_labels = config.get('use_bow_labels', True)

    # TODO: Create a dictionary just for cell arguments.
    # NOTE: `config` is aliased, not copied, so this also writes 'input_dims'
    # into the caller's config dictionary.
    self.vrnn_cell_args = config
    self.vrnn_cell_args['input_dims'] = self.input_dims

    if target_op is not None or self.is_training or self.is_validation:
      self.target_pieces = tf.split(self.targets, target_dims, axis=2)
      # TODO Swap pen and char targets. Parent `VRNNGMM` class expects class labels as the second entry.
      tmp_targets_pen = self.target_pieces[1]
      self.target_pieces[1] = self.target_pieces[2]
      self.target_pieces[2] = tmp_targets_pen

    # See `create_image_summary` method for details.
    self.img_summary_entries = []
    self.ops_img_summary = {}
    self.use_img_summary = False

  def get_constructors(self):
    """Sets the RNN cell class (`HandWritingVRNNGmmCell`)."""
    self.vrnn_cell_constructor = HandWritingVRNNGmmCell

  def build_predictions_layer(self):
    """Unpacks the cell outputs and registers evaluation and summary ops.

    The layout of `self.outputs` depends on `use_temporal_latent_space` and
    `use_variational_pi`; see the unpacking below.
    """
    # Assign rnn outputs.
    if self.use_temporal_latent_space and self.use_variational_pi:
      self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.gmm_z, self.q_pi, self.p_pi, self.out_mu, self.out_sigma, self.out_rho, self.out_pen, self.out_eoc = self.outputs
    elif self.use_temporal_latent_space:
      self.q_mu, self.q_sigma, self.p_mu, self.p_sigma, self.gmm_z, self.q_pi, self.out_mu, self.out_sigma, self.out_rho, self.out_pen, self.out_eoc = self.outputs
    elif self.use_variational_pi:
      self.gmm_z, self.q_pi, self.p_pi, self.out_mu, self.out_sigma, self.out_rho, self.out_pen, self.out_eoc = self.outputs
    # NOTE(review): when both flags are False nothing is unpacked here, so
    # the attributes used below (e.g. `self.out_mu`) must be set elsewhere.

    # TODO: Sampling option.
    # Predicted mean stroke followed by the rounded pen output.
    self.output_sample = tf.concat(
        [self.out_mu, tf.round(self.out_pen)], axis=2)
    self.input_sample = self.inputs
    self.output_dim = self.output_sample.shape.as_list()[-1]

    # For analysis.
    # p_*/q_* Gaussian latent parameters are only unpacked when the temporal
    # latent space is used. Previously these norms were computed
    # unconditionally and failed when it was disabled.
    if self.use_temporal_latent_space:
      self.norm_p_mu = tf.norm(self.p_mu, axis=-1)
      self.norm_p_sigma = tf.norm(self.p_sigma, axis=-1)
      self.norm_q_mu = tf.norm(self.q_mu, axis=-1)
      self.norm_q_sigma = tf.norm(self.q_sigma, axis=-1)
    self.norm_out_mu = tf.norm(self.out_mu, axis=-1)
    self.norm_out_sigma = tf.norm(self.out_sigma, axis=-1)

    self.ops_evaluation['output_sample'] = self.output_sample
    if self.use_temporal_latent_space:
      self.ops_evaluation['p_mu'] = self.p_mu
      self.ops_evaluation['p_sigma'] = self.p_sigma
      self.ops_evaluation['q_mu'] = self.q_mu
      self.ops_evaluation['q_sigma'] = self.q_sigma
    if self.use_variational_pi:
      self.ops_evaluation['p_pi'] = tf.nn.softmax(self.p_pi, axis=-1)
    self.ops_evaluation['q_pi'] = tf.nn.softmax(self.q_pi, axis=-1)

    self.ops_evaluation['gmm_z'] = self.gmm_z
    self.ops_evaluation['state'] = self.output_state
    self.ops_evaluation['out_eoc'] = self.out_eoc

    # In case we want to draw samples from output distribution instead of using mean.
    self.ops_evaluation['out_mu'] = self.out_mu
    self.ops_evaluation['out_sigma'] = self.out_sigma
    self.ops_evaluation['out_rho'] = self.out_rho
    self.ops_evaluation['out_pen'] = self.out_pen

    # Visualize average gmm sigma values.
    if self.is_gmm_active:
      self.ops_scalar_summary['mean_gmm_sigma'] = tf.reduce_mean(self.gmm_sigma)

    # Sequence mask for precise loss calculation.
    # float32 mask (batch, max_len, 1): 1 within each sequence, 0 on padding.
    self.seq_loss_mask = tf.expand_dims(
        tf.sequence_mask(
            lengths=self.input_seq_length,
            maxlen=tf.reduce_max(self.input_seq_length),
            dtype=tf.float32), -1)

  def build_loss(self):
    """Adds stroke, pen and eoc losses, then calls `VRNNGMM.build_loss`.

    Only active in training/validation mode. Losses are masked with
    `self.seq_loss_mask`, and losses already in `self.ops_loss` are not
    rebuilt.

    Raises:
      Exception: if `self.reconstruction_loss` is not one of 'nll_normal_bi',
        'nll_normal_diag', 'l1' or 'mse'.
    """
    if self.is_training or self.is_validation:
      # Indices follow the swap in __init__: [0] strokes, [2] pen, [3] eoc.
      targets_mu = self.target_pieces[0]
      targets_pen = self.target_pieces[2]
      targets_eoc = self.target_pieces[3]

      if self.reconstruction_loss_key not in self.ops_loss:
        with tf.name_scope('reconstruction_loss'):
          # Gaussian log likelihood loss (bivariate)
          if self.reconstruction_loss == 'nll_normal_bi':
            self.ops_loss[
                self.
                reconstruction_loss_key] = -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_bivariate(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        self.out_rho,
                        reduce_sum=False))
          # Gaussian log likelihood loss (diagonal covariance)
          elif self.reconstruction_loss == 'nll_normal_diag':
            self.ops_loss[
                self.
                reconstruction_loss_key] = -self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * logli_normal_diag_cov(
                        targets_mu,
                        self.out_mu,
                        self.out_sigma,
                        reduce_sum=False))
          # L1 norm.
          elif self.reconstruction_loss == 'l1':
            self.ops_loss[
                self.
                reconstruction_loss_key] = self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.absolute_difference(
                        targets_mu, self.out_mu, reduction='none'))
          # Mean-squared error.
          elif self.reconstruction_loss == 'mse':
            self.ops_loss[
                self.
                reconstruction_loss_key] = self.reconstruction_loss_weight * self.reduce_loss_func(
                    self.seq_loss_mask * tf.losses.mean_squared_error(
                        targets_mu, self.out_mu, reduction='none'))
          else:
            raise Exception('Undefined loss.')

      if 'loss_pen' not in self.ops_loss:
        with tf.name_scope('pen_reconstruction_loss'):
          # Bernoulli loss for pen information.
          self.ops_loss[
              'loss_pen'] = -self.pen_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask *
                  logli_bernoulli(targets_pen, self.out_pen, reduce_sum=False))

      if 'loss_eoc' not in self.ops_loss:
        with tf.name_scope('eoc_loss'):
          # Bernoulli loss for the end-of-character output.
          self.ops_loss[
              'loss_eoc'] = -self.eoc_loss_weight * self.reduce_loss_func(
                  self.seq_loss_mask *
                  logli_bernoulli(targets_eoc, self.out_eoc, reduce_sum=False))

      VRNNGMM.build_loss(self)

  ########################################
  # Evaluation methods.
  ########################################
  def sample_func(self,
                  session,
                  seq_len,
                  prev_state,
                  ops_eval,
                  text,
                  eoc_threshold,
                  cursive_threshold,
                  use_sample_mean=True):
    """Generates handwriting for `text` by running the model one step at a time.

    Each word starts with a zero character label and the bow flag set. The
    model is then fed the current character label. When the predicted
    end-of-character value exceeds `eoc_threshold`, and more than 4 steps
    have passed since the previous one, sampling moves to the next character.
    After the last character of a word, or after every character when not in
    cursive style, one blank step (zero label, bow flag 0) is fed. That step
    is kept only at the end of a word or when its pen flag is 1.

    Args:
        session: tensorflow session.
        seq_len: step budget for the whole text (# of frames).
        prev_state: initial RNN state fed to `self.initial_state`.
        ops_eval: dict of ops evaluated at every step. If None,
          `self.ops_evaluation` is used. The ops required for sampling
          ('out_eoc', 'output_sample', 'state') are added to a copy, so the
          caller's dictionary is not modified.
        text: text to write. Words are separated by spaces; empty words
          (e.g. from repeated spaces) are skipped.
        eoc_threshold: threshold on the end-of-character output.
        cursive_threshold: values above 0.5 select cursive style (no blank
          step between characters of a word).
        use_sample_mean: if True, use the predicted mean stroke. If exactly
          False, sample the stroke from the predicted bivariate normal.

    Returns:
        A dict mapping every evaluated op name to its per-step results
        concatenated along axis 1.
    """
    cursive_style = cursive_threshold > 0.5

    if ops_eval is None:
      ops_eval = self.ops_evaluation
    # Work on a copy so the ops added below do not leak into the caller's
    # dictionary (e.g. `self.ops_img_summary`).
    ops_eval = dict(ops_eval)
    # These ops are required by the sampling function.
    if 'out_eoc' not in ops_eval:
      ops_eval['out_eoc'] = self.out_eoc
    if 'output_sample' not in ops_eval:
      ops_eval['output_sample'] = self.output_sample
    if 'state' not in ops_eval:
      ops_eval['state'] = self.output_state

    # Since we draw one sample at a time, we need to accumulate the results.
    output_container = {key: [] for key in ops_eval}

    def one_step(feed_dict, save=True):
      """Runs one step; stores results if `save` or the pen flag equals 1."""
      eval_results = session.run(ops_eval, feed_dict)

      # Index 2 of `output_sample` is presumably the rounded pen value
      # (`out_mu` is assumed to be 2-D) — TODO confirm.
      if save or (eval_results['output_sample'][0, 0, 2] == 1):
        for key in output_container.keys():
          output_container[key].append(eval_results[key])

        if use_sample_mean is False:
          # Replace the mean stroke with a sample from the predicted
          # bivariate normal. The diagonal uses the squared `out_sigma`, so
          # `out_sigma` holds standard deviations and the covariance term is
          # rho * std1 * std2. The previous code multiplied the squared
          # values.
          std1, std2 = eval_results['out_sigma'][0, 0]
          rho = eval_results['out_rho'][0, 0, 0]
          covariance = rho * std1 * std2
          cov_matrix = [[np.square(std1), covariance],
                        [covariance, np.square(std2)]]
          stroke_sample = np.reshape(
              np.random.multivariate_normal(eval_results['out_mu'][0][0],
                                            cov_matrix), (1, 1, -1))
          output_container['output_sample'][-1] = np.concatenate(
              [stroke_sample, np.round(eval_results['out_pen'])], axis=-1)

      return eval_results['out_eoc'], eval_results[
          'output_sample'], eval_results['state']

    use_bow_labels = self.use_bow_labels

    def prepare_model_input(char_label, bow_label):
      """Builds a (1, 1, D) model input: 3 zero stroke values + labels."""
      if use_bow_labels:
        return np.concatenate([np.zeros((1, 1, 3)), char_label, bow_label],
                              axis=-1)
      else:
        return np.concatenate([np.zeros((1, 1, 3)), char_label], axis=-1)

    # NOTE: the 70-entry character label size is hard-coded.
    zero_char_label = np.zeros((1, 1, 70))
    bow_label = np.ones((1, 1, 1))
    non_bow_label = np.zeros((1, 1, 1))

    # Empty words (repeated/leading/trailing spaces) have no characters to
    # condition on and would make the reshape and indexing below fail.
    words = [word for word in text.split(' ') if word]

    prev_eoc_step = 0
    step = 0
    for word in words:
      char_idx = 0

      text_char_labels = np.reshape(
          self.text_to_label_fn(list(word)), (len(word), 1, 1, -1))
      char_label = zero_char_label

      # The first step of a word has no character label and the bow flag set.
      prev_x = prepare_model_input(char_label, bow_label)

      last_step = False
      while step < seq_len:
        if last_step:
          break
        step += 1
        feed = {
            self.inputs: prev_x,
            self.input_seq_length: np.ones(1),
            self.initial_state: prev_state
        }

        eoc, output_stroke, prev_state = one_step(feed_dict=feed)

        if np.squeeze(eoc) > eoc_threshold and (step - prev_eoc_step) > 4:
          prev_eoc_step = step

          char_idx += 1
          if char_idx == len(word):
            last_step = True
            # Keep the index valid for the final prepare_model_input below.
            char_idx -= 1

          if last_step or (not cursive_style):
            # Peek one step ahead with blank step.
            prev_x = prepare_model_input(zero_char_label, non_bow_label)

            step += 1
            feed = {
                self.inputs: prev_x,
                self.input_seq_length: np.ones(1),
                self.initial_state: prev_state
            }

            eoc, output_stroke, prev_state = one_step(
                feed_dict=feed, save=last_step)

        prev_x = prepare_model_input(text_char_labels[char_idx], non_bow_label)
    # Concatenate output lists.
    for key, val in output_container.items():
      output_container[key] = np.concatenate(val, axis=1)

    return output_container

  def sample_biased(self,
                    session,
                    seq_len,
                    prev_state,
                    prev_sample=None,
                    ops_eval=None,
                    **kwargs):
    """Generates text continuing from `prev_state` and an optional reference.

    Args:
        session: tensorflow session.
        seq_len: total length budget. The length of `prev_sample` (if any) is
          subtracted before sampling.
        prev_state: RNN state to start sampling from.
        prev_sample: optional reference strokes, shape (seq, dim) or
          (1, seq, dim). If given, it is prepended to the generated strokes.
        ops_eval: ops to evaluate; see `sample_func`.
        **kwargs: 'conditional_inputs' (text), 'eoc_threshold',
          'cursive_threshold', 'use_sample_mean'.

    Returns:
        A one-element list holding the output dict from `sample_func`.
    """

    text = kwargs.get('conditional_inputs', 'test, Function. Example')
    eoc_threshold = kwargs.get('eoc_threshold', 0.15)
    cursive_threshold = kwargs.get('cursive_threshold', 0.10)
    use_sample_mean = kwargs.get('use_sample_mean', True)

    ref_len = 0
    if prev_sample is not None:
      prev_sample = np.expand_dims(
          prev_sample, axis=0) if prev_sample.ndim == 2 else prev_sample
      ref_len = prev_sample.shape[1]
    seq_len = seq_len - ref_len

    output_container = self.sample_func(session, seq_len, prev_state, ops_eval,
                                        text, eoc_threshold, cursive_threshold,
                                        use_sample_mean)

    if prev_sample is not None:
      # Repeat the reference's last step with index 2 set to 1.0 (presumably
      # pen up), and shift the first generated value at index 0 by 20 —
      # presumably a gap after the reference.
      last_prev_sample_step = np.expand_dims(
          prev_sample[:, -1, :].copy(), axis=0)
      last_prev_sample_step[0, 0, 2] = 1.0
      output_container['output_sample'][
          0, 0, 0] = output_container['output_sample'][0, 0, 0] + 20
      output_container['output_sample'] = np.concatenate(
          (prev_sample, last_prev_sample_step,
           output_container['output_sample']),
          axis=1)

    return [output_container]

  def sample_unbiased(self, session, seq_len=500, ops_eval=None, **kwargs):
    """Generates text starting from the cell's zero state.

    Args:
        session: tensorflow session.
        seq_len: step budget for the whole text.
        ops_eval: ops to evaluate; see `sample_func`.
        **kwargs: 'conditional_inputs' (text), 'eoc_threshold',
          'cursive_threshold', 'use_sample_mean'.

    Returns:
        A one-element list holding the output dict from `sample_func`.
    """
    text = kwargs.get('conditional_inputs', 'test, Function. Example')
    eoc_threshold = kwargs.get('eoc_threshold', 0.15)
    cursive_threshold = kwargs.get('cursive_threshold', 0.10)
    use_sample_mean = kwargs.get('use_sample_mean', True)

    prev_state = session.run(
        self.cell.zero_state(batch_size=1, dtype=tf.float32))
    output_container = self.sample_func(session, seq_len, prev_state, ops_eval,
                                        text, eoc_threshold, cursive_threshold,
                                        use_sample_mean)

    return [output_container]


def load_models(model_dir, validation_dataset):
  """Build the validation and sampling graphs and restore the latest checkpoint.

  Args:
    model_dir: directory containing `config.json` and the checkpoint files
      located by `tf.train.latest_checkpoint`.
    validation_dataset: dataset object supplying `input_dims` and
      `target_dims` for the validation graph. It is also passed to both
      models as `data_processor`.

  Returns:
    A `(model, inference_model, sess)` tuple: the sampling-mode model, the
    validation-mode model, and the session holding the restored weights.

  Raises:
    Exception: 'Model is not found.' when no checkpoint exists in
      `model_dir` or restoring it fails. The underlying error is chained as
      `__cause__`. Errors reading `config.json` propagate unchanged.
  """
  # Context manager so the config file handle is closed promptly.
  with open(os.path.join(model_dir, 'config.json'), 'r') as config_file:
    config = json.load(config_file)

  tf.reset_default_graph()
  Model_cls = HandwritingVRNNGmmModel

  batch_size = 1
  # Dimensions used for the sampling graph. They are hard-coded here rather
  # than taken from `validation_dataset`; presumably they match the
  # dataset's layout — TODO confirm.
  input_dims = [3, 70, 1]
  target_dims = [2, 1, 70, 1, 1]
  # None leaves the time dimension of the placeholders unspecified.
  data_sequence_length = None

  tf.disable_eager_execution()

  strokes = tf.placeholder(
      tf.float32,
      shape=[
          batch_size, data_sequence_length,
          sum(validation_dataset.input_dims)
      ])
  targets = tf.placeholder(
      tf.float32,
      shape=[
          batch_size, data_sequence_length,
          sum(validation_dataset.target_dims)
      ])
  sequence_length = tf.placeholder(tf.int32, shape=[batch_size])

  # Create inference graph.
  with tf.name_scope('validation'):
    inference_model = Model_cls(
        config,
        reuse=False,
        input_op=strokes,
        target_op=targets,
        input_seq_length_op=sequence_length,
        input_dims=validation_dataset.input_dims,
        target_dims=validation_dataset.target_dims,
        batch_size=batch_size,
        mode='validation',
        data_processor=validation_dataset)
    inference_model.build_graph()

  # Create sampling graph; reuse=True shares variables with the graph above.
  with tf.name_scope('sampling'):
    model = Model_cls(
        config,
        reuse=True,
        input_op=strokes,
        target_op=None,
        input_seq_length_op=sequence_length,
        input_dims=input_dims,
        target_dims=target_dims,
        batch_size=batch_size,
        mode='sampling',
        data_processor=validation_dataset)
    model.build_graph()

  # Create a session object and restore parameters into it.
  sess = tf.Session()
  try:
    saver = tf.train.Saver()
    checkpoint_path = tf.train.latest_checkpoint(model_dir)
    # latest_checkpoint returns None (rather than raising) when nothing is
    # found; report that explicitly instead of relying on a TypeError.
    if checkpoint_path is None:
      raise ValueError('No checkpoint found in %s' % (model_dir,))

    print('Loading model ' + checkpoint_path)
    saver.restore(sess, checkpoint_path)
  except Exception as err:
    # Don't leak the session on failure; keep the original cause attached.
    sess.close()
    raise Exception('Model is not found.') from err

  return model, inference_model, sess


def run_model(model,
              inference_model,
              sess,
              validation_dataset,
              valid_data_sample,
              new_text,
              similarity,
              seq_len=600):
  """Synthesize strokes primed on a prefix of a validation sample.

  A prefix of the sample's model input is fed through `inference_model`.
  Its resulting state then primes `model.sample_biased` on `new_text`, or
  on the sample's own text when `new_text` is falsy. The prefix covers
  roughly the first `similarity` fraction of the input and is extended to
  the next stroke end, so no stroke is cut in half.

  Args:
    model: sampling model exposing `sample_biased`.
    inference_model: model exposing `reconstruct_given_sample`.
    sess: session holding the restored weights.
    validation_dataset: object whose `undo_preprocess` is applied to the
      synthesized sample.
    valid_data_sample: indexable sample laid out as follows.
      - Index 2: the model input.
      - Index 3: the original strokes; rows whose column 2 equals 1 are
        treated as stroke ends.
      - Index 4: the sample's text.
    new_text: text to synthesize; falsy values fall back to the sample text.
    similarity: fraction of the input used for priming, presumably in
      [0, 1]. 0.0 primes with an empty prefix; if no stroke end lies at or
      after the cutoff, the whole sample is used.
    seq_len: number of steps requested from `sample_biased` (default 600).

  Returns:
    The output of `undo_preprocess` on the first synthesized sequence. A
    leading row whose column 2 equals 1 is dropped.
  """
  stroke_model_input = valid_data_sample[2]
  original_sample = valid_data_sample[3]
  orig_text = valid_data_sample[4]

  # Make sure we only cut at the border of a stroke: extend the requested
  # prefix up to (and including) the first stroke end at or after the cutoff.
  if similarity == 0.0:
    prefix_len = 0
  else:
    cutoff = int(len(stroke_model_input) * similarity)
    stroke_ends = np.where(original_sample[:, 2] == 1.)[0]
    later_ends = stroke_ends[stroke_ends >= cutoff]
    if len(later_ends):
      prefix_len = later_ends[0] + 1
    else:
      prefix_len = len(original_sample)
  stroke_model_input = stroke_model_input[:prefix_len]

  inference_results = inference_model.reconstruct_given_sample(
      session=sess, inputs=stroke_model_input)

  text = new_text if new_text else orig_text
  biased_sampling_results = model.sample_biased(
      session=sess,
      seq_len=seq_len,
      prev_state=inference_results[0]['state'],
      prev_sample=None,
      conditional_inputs=text,
      use_sample_mean=True)

  synthetic_sample = biased_sampling_results[0]['output_sample'][0]
  reconstructed_sample = validation_dataset.undo_preprocess(synthetic_sample)
  # Drop a leading stroke-end row; the length check avoids an IndexError
  # when the synthesized sample is empty.
  if len(reconstructed_sample) and reconstructed_sample[0, 2] == 1:
    reconstructed_sample = reconstructed_sample[1:]
  return reconstructed_sample
