# Copyright 2017 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LSTM based modules for TensorFlow snt.

This python module contains LSTM-like cores that fall under the broader group
of RNN cores.  In general, initializers for the gate weights and other
model parameters may be passed to the constructor.

Typical usage example of the standard LSTM without peephole connections:

  ```
  import sonnet as snt

  hidden_size = 10
  batch_size = 2

  # Simple LSTM op on some input
  rnn = snt.LSTM(hidden_size)
  input = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
  out, next_state = rnn(input, rnn.initial_state(batch_size))
  ```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin
from sonnet.python.modules import base
from sonnet.python.modules import basic
from sonnet.python.modules import batch_norm
from sonnet.python.modules import conv
from sonnet.python.modules import layer_norm
from sonnet.python.modules import rnn_core
from sonnet.python.modules import util
import tensorflow as tf

from tensorflow.python.ops import array_ops


LSTMState = collections.namedtuple("LSTMState", ("hidden", "cell"))


class LSTM(rnn_core.RNNCore):
  """LSTM recurrent network cell with optional peepholes & layer normalization.

  The implementation is based on: http://arxiv.org/abs/1409.2329. We add
  forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  #### Layer normalization

  This is described in https://arxiv.org/pdf/1607.06450.pdf

  #### Peep-hole connections

  Peep-hole connections may optionally be used by specifying a flag in the
  constructor. These connections can aid increasing the precision of output
  timing, for more details see:

    https://research.google.com/pubs/archive/43905.pdf

  #### Recurrent projections

  Projection of the recurrent state, to reduce model parameters and speed up
  computation. For more details see:

    https://arxiv.org/abs/1402.1128

  Attributes:
    state_size: Tuple of `tf.TensorShape`s indicating the size of state tensors.
    output_size: `tf.TensorShape` indicating the size of the core output.
    use_peepholes: Boolean indicating whether peephole connections are used.
  """
  # Keys that may be provided for parameter initializers.
  W_GATES = "w_gates"  # weight for gates
  B_GATES = "b_gates"  # bias of gates
  W_F_DIAG = "w_f_diag"  # weight for prev_cell -> forget gate peephole
  W_I_DIAG = "w_i_diag"  # weight for prev_cell -> input gate peephole
  W_O_DIAG = "w_o_diag"  # weight for prev_cell -> output gate peephole
  W_H_PROJECTION = "w_h_projection"  # weight for (opt) projection of h in state
  POSSIBLE_INITIALIZER_KEYS = {
      W_GATES, B_GATES, W_F_DIAG, W_I_DIAG, W_O_DIAG, W_H_PROJECTION}

  def __init__(self,
               hidden_size,
               forget_bias=1.0,
               initializers=None,
               partitioners=None,
               regularizers=None,
               use_peepholes=False,
               use_layer_norm=False,
               hidden_clip_value=None,
               projection_size=None,
               cell_clip_value=None,
               custom_getter=None,
               name="lstm", bayesian_variables=None):
    """Construct LSTM.

    Args:
      hidden_size: (int) Hidden size dimensionality.
      forget_bias: (float) Bias for the forget activation.
      initializers: Dict containing ops to initialize the weights.
        This dictionary may contain any of the keys returned by
        `LSTM.get_possible_initializer_keys`.
      partitioners: Optional dict containing partitioners to partition
        the weights and biases. As a default, no partitioners are used. This
        dict may contain any of the keys returned by
        `LSTM.get_possible_initializer_keys`.
      regularizers: Optional dict containing regularizers for the weights and
        biases. As a default, no regularizers are used. This dict may contain
        any of the keys returned by
        `LSTM.get_possible_initializer_keys`.
      use_peepholes: Boolean that indicates whether peephole connections are
        used.
      use_layer_norm: Boolean that indicates whether to apply layer
        normalization.
      hidden_clip_value: Optional number; if set, then the LSTM hidden state
        vector is clipped by this value.
      projection_size: Optional number; if set, then the LSTM hidden state is
        projected to this size via a learnable projection matrix.
      cell_clip_value: Optional number; if set, then the LSTM cell vector is
        clipped by this value.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. See the
        `tf.get_variable` documentation for more details.
      name: Name of the module.

    Raises:
      KeyError: if `initializers` contains any keys not returned by
        `LSTM.get_possible_initializer_keys`.
      KeyError: if `partitioners` contains any keys not returned by
        `LSTM.get_possible_initializer_keys`.
      KeyError: if `regularizers` contains any keys not returned by
        `LSTM.get_possible_initializer_keys`.
      ValueError: if a peephole initializer is passed in the initializer list,
        but `use_peepholes` is False.
    """
    super(LSTM, self).__init__(custom_getter=custom_getter, name=name)

    self._hidden_size = hidden_size
    self._forget_bias = forget_bias
    self._use_peepholes = use_peepholes
    self._use_layer_norm = use_layer_norm
    self._hidden_clip_value = hidden_clip_value
    self._cell_clip_value = cell_clip_value
    self._use_projection = projection_size is not None
    self._hidden_state_size = projection_size or hidden_size

    self.possible_keys = self.get_possible_initializer_keys(
        use_peepholes=use_peepholes, use_projection=self._use_projection)
    self._initializers = util.check_initializers(initializers,
                                                 self.possible_keys)
    self._partitioners = util.check_initializers(partitioners,
                                                 self.possible_keys)
    self._regularizers = util.check_initializers(regularizers,
                                                 self.possible_keys)
    if hidden_clip_value is not None and hidden_clip_value < 0:
      raise ValueError("The value of hidden_clip_value should be nonnegative.")
    if cell_clip_value is not None and cell_clip_value < 0:
      raise ValueError("The value of cell_clip_value should be nonnegative.")

    #YC modified:
    self.bayesian_variables= bayesian_variables

  @classmethod
  def get_possible_initializer_keys(cls, use_peepholes=False,
                                    use_projection=False):
    """Returns the keys the dictionary of variable initializers may contain.

    The set of all possible initializer keys are:
      w_gates:  weight for gates
      b_gates:  bias of gates
      w_f_diag: weight for prev_cell -> forget gate peephole
      w_i_diag: weight for prev_cell -> input gate peephole
      w_o_diag: weight for prev_cell -> output gate peephole

    Args:
      cls:The class.
      use_peepholes: Boolean that indicates whether peephole connections are
        used.
      use_projection: Boolean that indicates whether a recurrent projection
        layer is used.

    Returns:
      Set with strings corresponding to the strings that may be passed to the
        constructor.
    """

    possible_keys = cls.POSSIBLE_INITIALIZER_KEYS.copy()
    if not use_peepholes:
      possible_keys.difference_update(
          {cls.W_F_DIAG, cls.W_I_DIAG, cls.W_O_DIAG})
    if not use_projection:
      possible_keys.difference_update({cls.W_H_PROJECTION})
    return possible_keys

  def _build(self, inputs, prev_state):
    """Connects the LSTM module into the graph.

    If this is not the first time the module has been connected to the graph,
    the Tensors provided as inputs and state must have the same final
    dimension, in order for the existing variables to be the correct size for
    their corresponding multiplications. The batch size may differ for each
    connection.

    Args:
      inputs: Tensor of size `[batch_size, input_size]`.
      prev_state: Tuple (prev_hidden, prev_cell).

    Returns:
      A tuple (output, next_state) where 'output' is a Tensor of size
      `[batch_size, hidden_size]` and 'next_state' is a `LSTMState` namedtuple
      (next_hidden, next_cell) where `next_hidden` and `next_cell` have size
      `[batch_size, hidden_size]`. If `projection_size` is specified, then
      `next_hidden` will have size `[batch_size, projection_size]`.
    Raises:
      ValueError: If connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match previous
        invocations.
    """
    prev_hidden, prev_cell = prev_state

    # pylint: disable=invalid-unary-operand-type
    if self._hidden_clip_value is not None:
      prev_hidden = tf.clip_by_value(
          prev_hidden, -self._hidden_clip_value, self._hidden_clip_value)
    if self._cell_clip_value is not None:
      prev_cell = tf.clip_by_value(
          prev_cell, -self._cell_clip_value, self._cell_clip_value)
    # pylint: enable=invalid-unary-operand-type

    #print (inputs, "gru rnn inputttttttttttttt")
    self._create_gate_variables(inputs.get_shape(), inputs.dtype)

    # pylint false positive: calling module of same file;
    # pylint: disable=not-callable

    # Parameters of gates are concatenated into one multiply for efficiency.
    inputs_and_hidden = tf.concat([inputs, prev_hidden], 1)
    gates = tf.matmul(inputs_and_hidden, self._w_xh)

    if self._use_layer_norm:
      gates = layer_norm.LayerNorm()(gates)

    gates += self._b

    # i = input_gate, j = next_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=gates, num_or_size_splits=4, axis=1)

    if self._use_peepholes:  # diagonal connections
      self._create_peephole_variables(inputs.dtype)
      f += self._w_f_diag * prev_cell
      i += self._w_i_diag * prev_cell

    forget_mask = tf.sigmoid(f + self._forget_bias)
    next_cell = forget_mask * prev_cell + tf.sigmoid(i) * tf.tanh(j)
    cell_output = next_cell
    if self._use_peepholes:
      cell_output += self._w_o_diag * cell_output
    next_hidden = tf.tanh(cell_output) * tf.sigmoid(o)

    if self._use_projection:
      next_hidden = tf.matmul(next_hidden, self._w_h_projection)

    return next_hidden, LSTMState(hidden=next_hidden, cell=next_cell)

  def _create_gate_variables(self, input_shape, dtype):
    """Initialize the variables used for the gates."""
    if len(input_shape) != 2:
      raise ValueError(
          "Rank of shape must be {} not: {}".format(2, len(input_shape)))

    equiv_input_size = self._hidden_state_size + input_shape.dims[1].value
    initializer = basic.create_linear_initializer(equiv_input_size)

    # YC modified:
    if self.bayesian_variables ==None:
      self._w_xh = tf.get_variable(
          self.W_GATES,
          shape=[equiv_input_size, 4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.W_GATES, initializer),
          partitioner=self._partitioners.get(self.W_GATES),
          regularizer=self._regularizers.get(self.W_GATES))
      print (equiv_input_size, self._hidden_state_size, input_shape,self._w_xh, 'wocannnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn')
      print (initializer, self._initializers.get(self.W_GATES, initializer), self._partitioners.get(self.W_GATES), self._regularizers.get(self.W_GATES))
      
      self._b = tf.get_variable(
          self.B_GATES,
          shape=[4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.B_GATES, initializer),
          partitioner=self._partitioners.get(self.B_GATES),
          regularizer=self._regularizers.get(self.B_GATES))
    # YC's modification for Bayesian LSTM:
    else:
      self._w_xh=self.bayesian_variables[0]
      self._b =self.bayesian_variables[1]


    if self._use_projection:
      w_h_initializer = basic.create_linear_initializer(self._hidden_size)
      self._w_h_projection = tf.get_variable(
          self.W_H_PROJECTION,
          shape=[self._hidden_size, self._hidden_state_size],
          dtype=dtype,
          initializer=self._initializers.get(self.W_H_PROJECTION,
                                             w_h_initializer),
          partitioner=self._partitioners.get(self.W_H_PROJECTION),
          regularizer=self._regularizers.get(self.W_H_PROJECTION))

  def _create_peephole_variables(self, dtype):
    """Initialize the variables used for the peephole connections."""
    self._w_f_diag = tf.get_variable(
        self.W_F_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_F_DIAG),
        partitioner=self._partitioners.get(self.W_F_DIAG),
        regularizer=self._regularizers.get(self.W_F_DIAG))
    self._w_i_diag = tf.get_variable(
        self.W_I_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_I_DIAG),
        partitioner=self._partitioners.get(self.W_I_DIAG),
        regularizer=self._regularizers.get(self.W_I_DIAG))
    self._w_o_diag = tf.get_variable(
        self.W_O_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_O_DIAG),
        partitioner=self._partitioners.get(self.W_O_DIAG),
        regularizer=self._regularizers.get(self.W_O_DIAG))

  @property
  def state_size(self):
    """Tuple of `tf.TensorShape`s indicating the size of state tensors."""
    return LSTMState(tf.TensorShape([self._hidden_state_size]),
                     tf.TensorShape([self._hidden_size]))

  @property
  def output_size(self):
    """`tf.TensorShape` indicating the size of the core output."""
    return tf.TensorShape([self._hidden_state_size])

  @property
  def use_peepholes(self):
    """Boolean indicating whether peephole connections are used."""
    return self._use_peepholes

  @property
  def use_layer_norm(self):
    """Boolean indicating whether layer norm is enabled."""
    return self._use_layer_norm


class RecurrentDropoutWrapper(rnn_core.RNNCore):
  """Wraps an RNNCore so that recurrent dropout can be applied."""

  def __init__(self, core, keep_probs):
    """Builds a new wrapper around a given core.

    Args:
      core: the RNN core to be wrapped.
      keep_probs: the recurrent dropout keep probabilities to apply.
        This should have the same structure has core.init_state. No dropout is
        applied for leafs set to None.
    """

    super(RecurrentDropoutWrapper, self).__init__(
        custom_getter=None, name=core.module_name + "_recdropout")
    self._core = core
    self._keep_probs = keep_probs

    # self._dropout_state_size is a list of shape for the state parts to which
    # dropout is to be applied.
    # self._dropout_index has the same shape as the core state. Leafs contain
    # either None if no dropout is applied or an integer representing an index
    # in self._dropout_state_size.
    self._dropout_state_size = []

    def set_dropout_state_size(keep_prob, state_size):
      if keep_prob is not None:
        self._dropout_state_size.append(state_size)
        return len(self._dropout_state_size) - 1
      return None

    self._dropout_indexes = tf.contrib.framework.nest.map_structure(
        set_dropout_state_size, keep_probs, core.state_size)

  def _build(self, inputs, prev_state):
    core_state, dropout_masks = prev_state
    output, next_core_state = self._core(inputs, core_state)

    # Dropout masks are generated via tf.nn.dropout so they actually include
    # rescaling: the mask value is 1/keep_prob if no dropout is applied.
    next_core_state = tf.contrib.framework.nest.map_structure(
        lambda i, state: state if i is None else state * dropout_masks[i],
        self._dropout_indexes, next_core_state)

    return output, (next_core_state, dropout_masks)

  def initial_state(self, batch_size, dtype=tf.float32, trainable=False,
                    trainable_initializers=None, trainable_regularizers=None,
                    name=None):
    """Builds the default start state tensor of zeros."""
    core_initial_state = self._core.initial_state(
        batch_size, dtype=dtype, trainable=trainable,
        trainable_initializers=trainable_initializers,
        trainable_regularizers=trainable_regularizers, name=name)

    dropout_masks = [None] * len(self._dropout_state_size)
    def set_dropout_mask(index, state, keep_prob):
      if index is not None:
        ones = tf.ones_like(state, dtype=dtype)
        dropout_masks[index] = tf.nn.dropout(ones, keep_prob=keep_prob)
    tf.contrib.framework.nest.map_structure(
        set_dropout_mask,
        self._dropout_indexes, core_initial_state, self._keep_probs)

    return core_initial_state, dropout_masks

  @property
  def state_size(self):
    return self._core.state_size, self._dropout_state_size

  @property
  def output_size(self):
    return self._core.output_size


def lstm_with_recurrent_dropout(hidden_size, keep_prob=0.5, **kwargs):
  """LSTM with recurrent dropout.

  Args:
    hidden_size: the LSTM hidden size.
    keep_prob: the probability to keep an entry when applying dropout.
    **kwargs: Extra keyword arguments to pass to the LSTM.

  Returns:
    A tuple (train_lstm, test_lstm) where train_lstm is an LSTM with
    recurrent dropout enabled to be used for training and test_lstm
    is the same LSTM without recurrent dropout.
  """

  lstm = LSTM(hidden_size, **kwargs)
  return RecurrentDropoutWrapper(lstm, LSTMState(keep_prob, None)), lstm


class ZoneoutWrapper(rnn_core.RNNCore):
  """Wraps an RNNCore so that zoneout can be applied.

  Zoneout was introduced in https://arxiv.org/abs/1606.01305
  It consists of randomly freezing some RNN state in the same way recurrent
  dropout would replace this state with zero.
  """

  def __init__(self, core, keep_probs, is_training):
    """Builds a new wrapper around a given core.

    Args:
      core: the RNN core to be wrapped.
      keep_probs: the probabilities to use the updated states rather than
        keeping the old state values. This is one minus the probability
        that zoneout gets applied.
        This should have the same structure has core.init_state. No zoneout is
        applied for leafs set to None.
      is_training: when set, apply some stochastic zoneout. Otherwise perform
        a linear combination of the previous state and the current state based
        on the zoneout probability.
    """

    super(ZoneoutWrapper, self).__init__(
        custom_getter=None, name=core.module_name + "_zoneout")
    self._core = core
    self._keep_probs = keep_probs
    self._is_training = is_training

  def _build(self, inputs, prev_state):
    output, next_state = self._core(inputs, prev_state)

    def apply_zoneout(keep_prob, next_s, prev_s):
      if keep_prob is None:
        return next_s
      if self._is_training:
        diff = next_s - prev_s
        # The dropout returns 0 with probability 1 - keep_prob and in this case
        # this function returns prev_s
        # It returns diff / keep_prob otherwise and then this function returns
        # prev_s + diff = next_s
        return prev_s + tf.nn.dropout(diff, keep_prob) * keep_prob
      else:
        return prev_s * (1 - keep_prob) + next_s * keep_prob

    next_state = tf.contrib.framework.nest.map_structure(
        apply_zoneout, self._keep_probs, next_state, prev_state)

    return output, next_state

  def initial_state(self, batch_size, dtype=tf.float32, trainable=False,
                    trainable_initializers=None, trainable_regularizers=None,
                    name=None):
    """Builds the default start state tensor of zeros."""
    return self._core.initial_state(
        batch_size, dtype=dtype, trainable=trainable,
        trainable_initializers=trainable_initializers,
        trainable_regularizers=trainable_regularizers, name=name)

  @property
  def state_size(self):
    return self._core.state_size

  @property
  def output_size(self):
    return self._core.output_size


def lstm_with_zoneout(hidden_size, keep_prob_c=0.5, keep_prob_h=0.95, **kwargs):
  """LSTM with recurrent dropout.

  Args:
    hidden_size: the LSTM hidden size.
    keep_prob_c: the probability to use the new value of the cell state rather
      than freezing it.
    keep_prob_h: the probability to use the new value of the hidden state
      rather than freezing it.
    **kwargs: Extra keyword arguments to pass to the LSTM.

  Returns:
    A tuple (train_lstm, test_lstm) where train_lstm is an LSTM with
    recurrent dropout enabled to be used for training and test_lstm
    is the same LSTM without zoneout.
  """

  lstm = LSTM(hidden_size, **kwargs)
  keep_probs = LSTMState(keep_prob_h, keep_prob_c)
  train_lstm = ZoneoutWrapper(lstm, keep_probs, is_training=True)
  test_lstm = ZoneoutWrapper(lstm, keep_probs, is_training=False)
  return train_lstm, test_lstm


class BatchNormLSTM(rnn_core.RNNCore):
  """LSTM recurrent network cell with optional peepholes, batch normalization.

  The base implementation is based on: http://arxiv.org/abs/1409.2329. We add
  forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  #### Peep-hole connections

  Peep-hole connections may optionally be used by specifying a flag in the
  constructor. These connections can aid increasing the precision of output
  timing, for more details see:

    https://research.google.com/pubs/archive/43905.pdf

  #### Batch normalization

  The batch norm transformation (in training mode) is
    batchnorm(x) = gamma * (x - mean(x)) / stddev(x) + beta,
  where gamma is a learnt scaling factor and beta is a learnt offset.

  Batch normalization may optionally be used at different places in the LSTM by
  specifying flag(s) in the constructor. These are applied when calculating
  the gate activations and cell-to-hidden transformation. The set-up is based on

    https://arxiv.org/pdf/1603.09025.pdf

  ##### Batch normalization: where to apply?

  Batch norm can be applied in three different places in the LSTM:

    (h) To the W_h h_{t-1} contribution to the gates from the previous hiddens.
    (x) To the W_x x_t contribution to the gates from the current input.
    (c) To the cell value c_t when calculating the output h_t from the cell.

  (The notation here is consistent with the Recurrent Batch Normalization
  paper). Each of these can be controlled individually, because batch norm is
  expensive, and not all are necessary. The paper doesn't mention the relative
  effects of these different batch norms; however, experimentation with a
  shallow LSTM for the `permuted_mnist` sequence task suggests that (h) is the
  most important and the other two can be left off. For other tasks or deeper
  (stacked) LSTMs, other batch norm combinations may be more effective.

  ##### Batch normalization: collecting stats (training vs test)

  When switching to testing (see `LSTM.with_batch_norm_control`), we can use a
  mean and stddev learnt from the training data instead of using the statistics
  from the test data. (This both increases test accuracy because the statistics
  have less variance, and if the test data does not have the same distribution
  as the training data then we must use the training statistics to ensure the
  effective network does not change when switching to testing anyhow.)

  This does however introduces a slight subtlety. The first few time steps of
  the RNN tend to have varying statistics (mean and variance) before settling
  down to a steady value. Therefore in general, better performance is obtained
  by using separate statistics for the first few time steps, and then using the
  final set of statistics for all subsequent time steps. This is controlled by
  the parameter `max_unique_stats`. (We can't have an unbounded number of
  distinct statistics for both technical reasons and also for the case where
  test sequences are longer than anything seen in training.)

  You may be fine leaving it at its default value of 1. Small values (like 10)
  may achieve better performance on some tasks when testing with cached
  statistics.

  Attributes:
    state_size: Tuple of `tf.TensorShape`s indicating the size of state tensors.
    output_size: `tf.TensorShape` indicating the size of the core output.
    use_peepholes: Boolean indicating whether peephole connections are used.
    use_batch_norm_h: Boolean indicating whether batch norm (h) is enabled.
    use_batch_norm_x: Boolean indicating whether batch norm (x) is enabled.
    use_batch_norm_c: Boolean indicating whether batch norm (c) is enabled.
  """

  # Keys that may be provided for parameter initializers.
  W_GATES = "w_gates"  # weight for gates
  B_GATES = "b_gates"  # bias of gates
  W_F_DIAG = "w_f_diag"  # weight for prev_cell -> forget gate peephole
  W_I_DIAG = "w_i_diag"  # weight for prev_cell -> input gate peephole
  W_O_DIAG = "w_o_diag"  # weight for prev_cell -> output gate peephole
  GAMMA_H = "gamma_h"  # batch norm scaling for previous_hidden -> gates
  GAMMA_X = "gamma_x"  # batch norm scaling for input -> gates
  GAMMA_C = "gamma_c"  # batch norm scaling for cell -> output
  BETA_C = "beta_c"  # (batch norm) bias for cell -> output
  POSSIBLE_INITIALIZER_KEYS = {W_GATES, B_GATES, W_F_DIAG, W_I_DIAG, W_O_DIAG,
                               GAMMA_H, GAMMA_X, GAMMA_C, BETA_C}
  # Keep old name for backwards compatibility

  POSSIBLE_KEYS = POSSIBLE_INITIALIZER_KEYS

  def __init__(self,
               hidden_size,
               forget_bias=1.0,
               initializers=None,
               partitioners=None,
               regularizers=None,
               use_peepholes=False,
               use_batch_norm_h=True,
               use_batch_norm_x=False,
               use_batch_norm_c=False,
               max_unique_stats=1,
               hidden_clip_value=None,
               cell_clip_value=None,
               custom_getter=None,
               name="batch_norm_lstm"):
    """Construct `BatchNormLSTM`.

    Args:
      hidden_size: (int) Hidden size dimensionality.
      forget_bias: (float) Bias for the forget activation.
      initializers: Dict containing ops to initialize the weights.
        This dictionary may contain any of the keys returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
        The gamma and beta variables control batch normalization values for
        different batch norm transformations inside the cell; see the paper for
        details.
      partitioners: Optional dict containing partitioners to partition
        the weights and biases. As a default, no partitioners are used. This
        dict may contain any of the keys returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
      regularizers: Optional dict containing regularizers for the weights and
        biases. As a default, no regularizers are used. This dict may contain
        any of the keys returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
      use_peepholes: Boolean that indicates whether peephole connections are
        used.
      use_batch_norm_h: Boolean that indicates whether to apply batch
        normalization at the previous_hidden -> gates contribution. If you are
        experimenting with batch norm then this may be the most effective to
        use, and is enabled by default.
      use_batch_norm_x: Boolean that indicates whether to apply batch
        normalization at the input -> gates contribution.
      use_batch_norm_c: Boolean that indicates whether to apply batch
        normalization at the cell -> output contribution.
      max_unique_stats: The maximum number of steps to use unique batch norm
        statistics for. (See module description above for more details.)
      hidden_clip_value: Optional number; if set, then the LSTM hidden state
        vector is clipped by this value.
      cell_clip_value: Optional number; if set, then the LSTM cell vector is
        clipped by this value.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. See the
        `tf.get_variable` documentation for more details.
      name: Name of the module.

    Raises:
      KeyError: if `initializers` contains any keys not returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
      KeyError: if `partitioners` contains any keys not returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
      KeyError: if `regularizers` contains any keys not returned by
        `BatchNormLSTM.get_possible_initializer_keys`.
      ValueError: if a peephole initializer is passed in the initializer list,
        but `use_peepholes` is False.
      ValueError: if a batch norm initializer is passed in the initializer list,
        but batch norm is disabled.
      ValueError: if none of the `use_batch_norm_*` options are True.
      ValueError: if `max_unique_stats` is < 1.
    """
    if not any([use_batch_norm_h, use_batch_norm_x, use_batch_norm_c]):
      raise ValueError("At least one use_batch_norm_* option is required for "
                       "BatchNormLSTM")
    super(BatchNormLSTM, self).__init__(custom_getter=custom_getter, name=name)

    self._hidden_size = hidden_size
    self._forget_bias = forget_bias
    self._use_peepholes = use_peepholes
    self._max_unique_stats = max_unique_stats
    self._use_batch_norm_h = use_batch_norm_h
    self._use_batch_norm_x = use_batch_norm_x
    self._use_batch_norm_c = use_batch_norm_c
    self._hidden_clip_value = hidden_clip_value
    self._cell_clip_value = cell_clip_value
    self.possible_keys = self.get_possible_initializer_keys(
        use_peepholes=use_peepholes, use_batch_norm_h=use_batch_norm_h,
        use_batch_norm_x=use_batch_norm_x, use_batch_norm_c=use_batch_norm_c)
    self._initializers = util.check_initializers(initializers,
                                                 self.possible_keys)
    self._partitioners = util.check_initializers(partitioners,
                                                 self.possible_keys)
    self._regularizers = util.check_initializers(regularizers,
                                                 self.possible_keys)
    if max_unique_stats < 1:
      raise ValueError("max_unique_stats must be >= 1")
    if max_unique_stats != 1 and not (
        use_batch_norm_h or use_batch_norm_x or use_batch_norm_c):
      raise ValueError("max_unique_stats specified but batch norm disabled")
    if hidden_clip_value is not None and hidden_clip_value < 0:
      raise ValueError("The value of hidden_clip_value should be nonnegative.")
    if cell_clip_value is not None and cell_clip_value < 0:
      raise ValueError("The value of cell_clip_value should be nonnegative.")

    if use_batch_norm_h:
      self._batch_norm_h = BatchNormLSTM.IndexedStatsBatchNorm(max_unique_stats,
                                                               "batch_norm_h")
    if use_batch_norm_x:
      self._batch_norm_x = BatchNormLSTM.IndexedStatsBatchNorm(max_unique_stats,
                                                               "batch_norm_x")
    if use_batch_norm_c:
      self._batch_norm_c = BatchNormLSTM.IndexedStatsBatchNorm(max_unique_stats,
                                                               "batch_norm_c")

  def with_batch_norm_control(self, is_training, test_local_stats=True):
    """Wraps this RNNCore with the additional control input to the `BatchNorm`s.

    Example usage:

      lstm = snt.BatchNormLSTM(4)
      is_training = tf.placeholder(tf.bool)
      rnn_input = ...
      my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

    Args:
      is_training: Boolean that indicates whether we are in
        training mode or testing mode. When in training mode, the batch norm
        statistics are taken from the given batch, and moving statistics are
        updated. When in testing mode, the moving statistics are not updated,
        and in addition if `test_local_stats` is False then the moving
        statistics are used for the batch statistics. See the `BatchNorm` module
        for more details.
      test_local_stats: Boolean scalar indicated whether to use local
        batch statistics in test mode.

    Returns:
      snt.RNNCore wrapping this class with the extra input(s) added.
    """
    return BatchNormLSTM.CoreWithExtraBuildArgs(
        self, is_training=is_training, test_local_stats=test_local_stats)

  @classmethod
  def get_possible_initializer_keys(
      cls, use_peepholes=False, use_batch_norm_h=True, use_batch_norm_x=False,
      use_batch_norm_c=False):
    """Returns the keys the dictionary of variable initializers may contain.

    The set of all possible initializer keys are:
      w_gates:  weight for gates
      b_gates:  bias of gates
      w_f_diag: weight for prev_cell -> forget gate peephole
      w_i_diag: weight for prev_cell -> input gate peephole
      w_o_diag: weight for prev_cell -> output gate peephole
      gamma_h:  batch norm scaling for previous_hidden -> gates
      gamma_x:  batch norm scaling for input -> gates
      gamma_c:  batch norm scaling for cell -> output
      beta_c:   batch norm bias for cell -> output

    Args:
      cls:The class.
      use_peepholes: Boolean that indicates whether peephole connections are
        used.
      use_batch_norm_h: Boolean that indicates whether to apply batch
        normalization at the previous_hidden -> gates contribution. If you are
        experimenting with batch norm then this may be the most effective to
        turn on.
      use_batch_norm_x: Boolean that indicates whether to apply batch
        normalization at the input -> gates contribution.
      use_batch_norm_c: Boolean that indicates whether to apply batch
        normalization at the cell -> output contribution.

    Returns:
      Set with strings corresponding to the strings that may be passed to the
        constructor.
    """

    possible_keys = cls.POSSIBLE_INITIALIZER_KEYS.copy()
    if not use_peepholes:
      possible_keys.difference_update(
          {cls.W_F_DIAG, cls.W_I_DIAG, cls.W_O_DIAG})
    if not use_batch_norm_h:
      possible_keys.remove(cls.GAMMA_H)
    if not use_batch_norm_x:
      possible_keys.remove(cls.GAMMA_X)
    if not use_batch_norm_c:
      possible_keys.difference_update({cls.GAMMA_C, cls.BETA_C})
    return possible_keys

  def _build(self, inputs, prev_state, is_training=None, test_local_stats=True):
    """Connects the LSTM module into the graph.

    If this is not the first time the module has been connected to the graph,
    the Tensors provided as inputs and state must have the same final
    dimension, in order for the existing variables to be the correct size for
    their corresponding multiplications. The batch size may differ for each
    connection.

    Args:
      inputs: Tensor of size `[batch_size, input_size]`.
      prev_state: Tuple (prev_hidden, prev_cell), or if batch norm is enabled
        and `max_unique_stats > 1`, then (prev_hidden, prev_cell, time_step).
        Here, prev_hidden and prev_cell are tensors of size
        `[batch_size, hidden_size]`, and time_step is used to indicate the
        current RNN step.
      is_training: Boolean indicating whether we are in training mode (as
        opposed to testing mode), passed to the batch norm
        modules. Note to use this you must wrap the cell via the
        `with_batch_norm_control` function.
      test_local_stats: Boolean indicating whether to use local batch statistics
        in test mode. See the `BatchNorm` documentation for more on this.

    Returns:
      A tuple (output, next_state) where 'output' is a Tensor of size
      `[batch_size, hidden_size]` and 'next_state' is a tuple
      (next_hidden, next_cell) or (next_hidden, next_cell, time_step + 1),
      where next_hidden and next_cell have size `[batch_size, hidden_size]`.

    Raises:
      ValueError: If connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match previous
        invocations.
    """
    if is_training is None:
      raise ValueError("Boolean is_training flag must be explicitly specified "
                       "when using batch normalization.")

    if self._max_unique_stats == 1:
      prev_hidden, prev_cell = prev_state
      time_step = None
    else:
      prev_hidden, prev_cell, time_step = prev_state

    # pylint: disable=invalid-unary-operand-type
    if self._hidden_clip_value is not None:
      prev_hidden = tf.clip_by_value(
          prev_hidden, -self._hidden_clip_value, self._hidden_clip_value)
    if self._cell_clip_value is not None:
      prev_cell = tf.clip_by_value(
          prev_cell, -self._cell_clip_value, self._cell_clip_value)
    # pylint: enable=invalid-unary-operand-type

    self._create_gate_variables(inputs.get_shape(), inputs.dtype)
    self._create_batch_norm_variables(inputs.dtype)

    # pylint false positive: calling module of same file;
    # pylint: disable=not-callable

    if self._use_batch_norm_h or self._use_batch_norm_x:
      gates_h = tf.matmul(prev_hidden, self._w_h)
      gates_x = tf.matmul(inputs, self._w_x)
      if self._use_batch_norm_h:
        gates_h = self._gamma_h * self._batch_norm_h(gates_h,
                                                     time_step,
                                                     is_training,
                                                     test_local_stats)
      if self._use_batch_norm_x:
        gates_x = self._gamma_x * self._batch_norm_x(gates_x,
                                                     time_step,
                                                     is_training,
                                                     test_local_stats)
      gates = gates_h + gates_x
    else:
      # Parameters of gates are concatenated into one multiply for efficiency.
      inputs_and_hidden = tf.concat([inputs, prev_hidden], 1)
      gates = tf.matmul(inputs_and_hidden, self._w_xh)

    gates += self._b

    # i = input_gate, j = next_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=gates, num_or_size_splits=4, axis=1)

    if self._use_peepholes:  # diagonal connections
      self._create_peephole_variables(inputs.dtype)
      f += self._w_f_diag * prev_cell
      i += self._w_i_diag * prev_cell

    forget_mask = tf.sigmoid(f + self._forget_bias)
    next_cell = forget_mask * prev_cell + tf.sigmoid(i) * tf.tanh(j)
    cell_output = next_cell
    if self._use_batch_norm_c:
      cell_output = (self._beta_c
                     + self._gamma_c * self._batch_norm_c(cell_output,
                                                          time_step,
                                                          is_training,
                                                          test_local_stats))
    if self._use_peepholes:
      cell_output += self._w_o_diag * cell_output
    next_hidden = tf.tanh(cell_output) * tf.sigmoid(o)

    if self._max_unique_stats == 1:
      return next_hidden, (next_hidden, next_cell)
    else:
      return next_hidden, (next_hidden, next_cell, time_step + 1)

  def _create_batch_norm_variables(self, dtype):
    """Initialize the variables used for the `BatchNorm`s (if any)."""
    # The paper recommends a value of 0.1 for good gradient flow through the
    # tanh nonlinearity (although doesn't say whether this is for all gammas,
    # or just some).
    gamma_initializer = tf.constant_initializer(0.1)

    if self._use_batch_norm_h:
      self._gamma_h = tf.get_variable(
          self.GAMMA_H,
          shape=[4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.GAMMA_H, gamma_initializer),
          partitioner=self._partitioners.get(self.GAMMA_H),
          regularizer=self._regularizers.get(self.GAMMA_H))
    if self._use_batch_norm_x:
      self._gamma_x = tf.get_variable(
          self.GAMMA_X,
          shape=[4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.GAMMA_X, gamma_initializer),
          partitioner=self._partitioners.get(self.GAMMA_X),
          regularizer=self._regularizers.get(self.GAMMA_X))
    if self._use_batch_norm_c:
      self._gamma_c = tf.get_variable(
          self.GAMMA_C,
          shape=[self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.GAMMA_C, gamma_initializer),
          partitioner=self._partitioners.get(self.GAMMA_C),
          regularizer=self._regularizers.get(self.GAMMA_C))
      self._beta_c = tf.get_variable(
          self.BETA_C,
          shape=[self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.BETA_C),
          partitioner=self._partitioners.get(self.BETA_C),
          regularizer=self._regularizers.get(self.BETA_C))

  def _create_gate_variables(self, input_shape, dtype):
    """Initialize the variables used for the gates."""
    if len(input_shape) != 2:
      raise ValueError(
          "Rank of shape must be {} not: {}".format(2, len(input_shape)))
    input_size = input_shape.dims[1].value

    b_shape = [4 * self._hidden_size]

    equiv_input_size = self._hidden_size + input_size
    initializer = basic.create_linear_initializer(equiv_input_size)

    if self._use_batch_norm_h or self._use_batch_norm_x:
      self._w_h = tf.get_variable(
          self.W_GATES + "_H",
          shape=[self._hidden_size, 4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.W_GATES, initializer),
          partitioner=self._partitioners.get(self.W_GATES),
          regularizer=self._regularizers.get(self.W_GATES))
      self._w_x = tf.get_variable(
          self.W_GATES + "_X",
          shape=[input_size, 4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.W_GATES, initializer),
          partitioner=self._partitioners.get(self.W_GATES),
          regularizer=self._regularizers.get(self.W_GATES))
    else:
      self._w_xh = tf.get_variable(
          self.W_GATES,
          shape=[self._hidden_size + input_size, 4 * self._hidden_size],
          dtype=dtype,
          initializer=self._initializers.get(self.W_GATES, initializer),
          partitioner=self._partitioners.get(self.W_GATES),
          regularizer=self._regularizers.get(self.W_GATES))
    self._b = tf.get_variable(
        self.B_GATES,
        shape=b_shape,
        dtype=dtype,
        initializer=self._initializers.get(self.B_GATES, initializer),
        partitioner=self._partitioners.get(self.B_GATES),
        regularizer=self._regularizers.get(self.B_GATES))

  def _create_peephole_variables(self, dtype):
    """Initialize the variables used for the peephole connections."""
    self._w_f_diag = tf.get_variable(
        self.W_F_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_F_DIAG),
        partitioner=self._partitioners.get(self.W_F_DIAG),
        regularizer=self._regularizers.get(self.W_F_DIAG))
    self._w_i_diag = tf.get_variable(
        self.W_I_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_I_DIAG),
        partitioner=self._partitioners.get(self.W_I_DIAG),
        regularizer=self._regularizers.get(self.W_I_DIAG))
    self._w_o_diag = tf.get_variable(
        self.W_O_DIAG,
        shape=[self._hidden_size],
        dtype=dtype,
        initializer=self._initializers.get(self.W_O_DIAG),
        partitioner=self._partitioners.get(self.W_O_DIAG),
        regularizer=self._regularizers.get(self.W_O_DIAG))

  def initial_state(self, batch_size, dtype=tf.float32, trainable=False,
                    trainable_initializers=None, trainable_regularizers=None,
                    name=None):
    """Builds the default start state tensor of zeros.

    Args:
      batch_size: An int, float or scalar Tensor representing the batch size.
      dtype: The data type to use for the state.
      trainable: Boolean that indicates whether to learn the initial state.
      trainable_initializers: An optional pair of initializers for the
          initial hidden state and cell state.
      trainable_regularizers: Optional regularizer function or nested structure
        of functions with the same structure as the `state_size` property of the
        core, to be used as regularizers of the initial state variable. A
        regularizer should be a function that takes a single `Tensor` as an
        input and returns a scalar `Tensor` output, e.g. the L1 and L2
        regularizers in `tf.contrib.layers`.
      name: Optional string used to prefix the initial state variable names, in
          the case of a trainable initial state. If not provided, defaults to
          the name of the module.

    Returns:
      A tensor tuple `([batch_size, state_size], [batch_size, state_size], ?)`
      filled with zeros, with the third entry present when batch norm is enabled
      with `max_unique_stats > 1', with value `0` (representing the time step).
    """
    if self._max_unique_stats == 1:
      return super(BatchNormLSTM, self).initial_state(
          batch_size, dtype=dtype, trainable=trainable,
          trainable_initializers=trainable_initializers,
          trainable_regularizers=trainable_regularizers, name=name)
    else:
      with tf.name_scope(self._initial_state_scope(name)):
        if not trainable:
          state = self.zero_state(batch_size, dtype)
        else:
          # We have to manually create the state ourselves so we don't create a
          # variable that never gets used for the third entry.
          state = rnn_core.trainable_initial_state(
              batch_size,
              (tf.TensorShape([self._hidden_size]),
               tf.TensorShape([self._hidden_size])),
              dtype=dtype,
              initializers=trainable_initializers,
              regularizers=trainable_regularizers,
              name=self._initial_state_scope(name))
        return (state[0], state[1], tf.constant(0, dtype=tf.int32))

  @property
  def state_size(self):
    """Tuple of `tf.TensorShape`s indicating the size of state tensors."""
    if self._max_unique_stats == 1:
      return (tf.TensorShape([self._hidden_size]),
              tf.TensorShape([self._hidden_size]))
    else:
      return (tf.TensorShape([self._hidden_size]),
              tf.TensorShape([self._hidden_size]),
              tf.TensorShape(1))

  @property
  def output_size(self):
    """`tf.TensorShape` indicating the size of the core output."""
    return tf.TensorShape([self._hidden_size])

  @property
  def use_peepholes(self):
    """Boolean indicating whether peephole connections are used."""
    return self._use_peepholes

  @property
  def use_batch_norm_h(self):
    """Boolean indicating whether batch norm for hidden -> gates is enabled."""
    return self._use_batch_norm_h

  @property
  def use_batch_norm_x(self):
    """Boolean indicating whether batch norm for input -> gates is enabled."""
    return self._use_batch_norm_x

  @property
  def use_batch_norm_c(self):
    """Boolean indicating whether batch norm for cell -> output is enabled."""
    return self._use_batch_norm_c

  class IndexedStatsBatchNorm(base.AbstractModule):
    """BatchNorm module where batch statistics are selected by an input index.

    This is used by LSTM+batchnorm, where we have distinct batch norm statistics
    for the first `max_unique_stats` time steps, and then use the final set of
    statistics for subsequent time steps.

    The module has as input (x, index, is_training, test_local_stats). During
    training or when test_local_stats=True, the output is simply batchnorm(x)
    (where mean(x) and stddev(x) are used), and during training the
    `BatchNorm` module accumulates statistics in mean_i, etc, where
    i = min(index, max_unique_stats - 1).

    During testing with test_local_stats=False, the output is batchnorm(x),
    where mean_i and stddev_i are used instead of mean(x) and stddev(x).

    See the `BatchNorm` module for more on is_training and test_local_stats.

    No offset `beta` or scaling `gamma` are learnt.
    """

    def __init__(self, max_unique_stats, name=None):
      """Create an IndexedStatsBatchNorm.

      Args:
        max_unique_stats: number of different indices to have statistics for;
          indices beyond this will use the final statistics.
        name: Name of the module.
      """
      super(BatchNormLSTM.IndexedStatsBatchNorm, self).__init__(name=name)
      self._max_unique_stats = max_unique_stats

    def _build(self, inputs, index, is_training, test_local_stats):
      """Add the IndexedStatsBatchNorm module to the graph.

      Args:
        inputs: Tensor to apply batch norm to.
        index: Scalar TensorFlow int32 value to select the batch norm index.
        is_training: Boolean to indicate to `snt.BatchNorm` if we are
          currently training.
        test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
          normalization should  use local batch statistics at test time.

      Returns:
        Output of batch norm operation.
      """
      def create_batch_norm():
        return batch_norm.BatchNorm(offset=False, scale=False)(
            inputs, is_training, test_local_stats)

      if self._max_unique_stats > 1:
        pred_fn_pairs = [(tf.equal(i, index), create_batch_norm)
                         for i in xrange(self._max_unique_stats - 1)]
        out = tf.case(pred_fn_pairs, create_batch_norm)
        out.set_shape(inputs.get_shape())  # needed for tf.case shape inference
        return out
      else:
        return create_batch_norm()

  class CoreWithExtraBuildArgs(rnn_core.RNNCore):
    """Wraps an RNNCore so that the build method receives extra args and kwargs.

    This will pass the additional input `args` and `kwargs` to the _build
    function of the snt.RNNCore after the input and prev_state inputs.
    """

    def __init__(self, core, *args, **kwargs):
      """Construct the CoreWithExtraBuildArgs.

      Args:
        core: The snt.RNNCore to wrap.
        *args: Extra arguments to pass to _build.
        **kwargs: Extra keyword arguments to pass to _build.
      """
      super(BatchNormLSTM.CoreWithExtraBuildArgs, self).__init__(
          name=core.module_name + "_extra_args")
      self._core = core
      self._args = args
      self._kwargs = kwargs

    def _build(self, inputs, state):
      return self._core(inputs, state, *self._args, **self._kwargs)

    @property
    def state_size(self):
      """Tuple indicating the size of nested state tensors."""
      return self._core.state_size

    @property
    def output_size(self):
      """`tf.TensorShape` indicating the size of the core output."""
      return self._core.output_size


class ConvLSTM(rnn_core.RNNCore):
  """Convolutional LSTM."""

  @classmethod
  def get_possible_initializer_keys(cls, conv_ndims, use_bias=True):
    conv_class = cls._get_conv_class(conv_ndims)
    return conv_class.get_possible_initializer_keys(use_bias)

  @classmethod
  def _get_conv_class(cls, conv_ndims):
    if conv_ndims == 1:
      return conv.Conv1D
    elif conv_ndims == 2:
      return conv.Conv2D
    elif conv_ndims == 3:
      return conv.Conv3D
    else:
      raise ValueError("Invalid convolution dimensionality.")

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               padding=conv.SAME,
               use_bias=True,
               forget_bias=1.0,
               initializers=None,
               partitioners=None,
               regularizers=None,
               custom_getter=None,
               name="conv_lstm"):
    """Construct ConvLSTM.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as an iterable, excluding the batch size.
      output_channels: Number of output channels of the conv LSTM.
      kernel_shape: Sequence of kernel sizes (of size 2), or integer that is
          used to define kernel size in all dimensions.
      stride: Sequence of kernel strides (of size 2), or integer that is used to
          define stride in all dimensions.
      rate: Sequence of dilation rates (of size conv_ndims), or integer that is
          used to define dilation rate in all dimensions. 1 corresponds to a
          standard convolution, while rate > 1 corresponds to a dilated
          convolution. Cannot be > 1 if any of stride is also > 1.
      padding: Padding algorithm, either `snt.SAME` or `snt.VALID`.
      use_bias: Use bias in convolutions.
      forget_bias: Forget bias.
      initializers: Dict containing ops to initialize the convolutional weights.
      partitioners: Optional dict containing partitioners to partition
        the convolutional weights and biases. As a default, no partitioners are
        used.
      regularizers: Optional dict containing regularizers for the convolutional
        weights and biases. As a default, no regularizers are used.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. See the
        `tf.get_variable` documentation for more details.
      name: Name of the module.

    Raises:
      ValueError: If `skip_connection` is `True` and stride is different from 1
        or if `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTM, self).__init__(custom_getter=custom_getter, name=name)

    self._conv_class = self._get_conv_class(conv_ndims)

    if conv_ndims != len(input_shape)-1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = tuple(input_shape)
    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._stride = stride
    self._rate = rate
    self._padding = padding
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._initializers = initializers
    self._partitioners = partitioners
    self._regularizers = regularizers

    self._total_output_channels = output_channels
    if self._stride != 1:
      self._total_output_channels //= self._stride * self._stride

    self._convolutions = collections.defaultdict(self._new_convolution)

  def _new_convolution(self):
    return self._conv_class(
        output_channels=4*self._output_channels,
        kernel_shape=self._kernel_shape,
        stride=self._stride,
        rate=self._rate,
        padding=self._padding,
        use_bias=self._use_bias,
        initializers=self._initializers,
        partitioners=self._partitioners,
        regularizers=self._regularizers,
        name="conv")

  @property
  def convolutions(self):
    return self._convolutions

  @property
  def state_size(self):
    """Tuple of `tf.TensorShape`s indicating the size of state tensors."""
    hidden_size = tf.TensorShape(
        self._input_shape[:-1] + (self._output_channels,))
    return (hidden_size, hidden_size)

  @property
  def output_size(self):
    """`tf.TensorShape` indicating the size of the core output."""
    return tf.TensorShape(
        self._input_shape[:-1] + (self._total_output_channels,))

  def _build(self, inputs, state):
    hidden, cell = state
    input_conv = self._convolutions["input"]
    hidden_conv = self._convolutions["hidden"]
    next_hidden = input_conv(inputs) + hidden_conv(hidden)
    gates = tf.split(value=next_hidden, num_or_size_splits=4,
                     axis=self._conv_ndims+1)

    input_gate, next_input, forget_gate, output_gate = gates
    next_cell = tf.sigmoid(forget_gate + self._forget_bias) * cell
    next_cell += tf.sigmoid(input_gate) * tf.tanh(next_input)
    output = tf.tanh(next_cell) * tf.sigmoid(output_gate)
    return output, (output, next_cell)


class Conv1DLSTM(ConvLSTM):
  """1D convolutional LSTM."""

  @classmethod
  def get_possible_initializer_keys(cls, use_bias=True):
    return super(Conv1DLSTM, cls).get_possible_initializer_keys(1, use_bias)

  def __init__(self, name="conv_1d_lstm", **kwargs):
    """Construct Conv1DLSTM. See `snt.ConvLSTM` for more details."""
    super(Conv1DLSTM, self).__init__(conv_ndims=1, name=name, **kwargs)


class Conv2DLSTM(ConvLSTM):
  """2D convolutional LSTM."""


  @classmethod
  def get_possible_initializer_keys(cls, use_bias=True):
    return super(Conv2DLSTM, cls).get_possible_initializer_keys(2, use_bias)

  def __init__(self, name="conv_2d_lstm", **kwargs):
    """Construct Conv2DLSTM. See `snt.ConvLSTM` for more details."""
    super(Conv2DLSTM, self).__init__(conv_ndims=2, name=name, **kwargs)


class GRU(rnn_core.RNNCore):
  """GRU recurrent network cell.

  The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf.

  Attributes:
    state_size: Integer indicating the size of state tensor.
    output_size: Integer indicating the size of the core output.
  """

  # Keys that may be provided for parameter initializers.
  WZ = "wz"  # weight for input -> update cell
  UZ = "uz"  # weight for prev_state -> update cell
  BZ = "bz"  # bias for update_cell
  WR = "wr"  # weight for input -> reset cell
  UR = "ur"  # weight for prev_state -> reset cell
  BR = "br"  # bias for reset cell
  WH = "wh"  # weight for input -> candidate activation
  UH = "uh"  # weight for prev_state -> candidate activation
  BH = "bh"  # bias for candidate activation
  POSSIBLE_INITIALIZER_KEYS = {WZ, UZ, BZ, WR, UR, BR, WH, UH, BH}
  # Keep old name for backwards compatibility

  POSSIBLE_KEYS = POSSIBLE_INITIALIZER_KEYS

  def __init__(self, hidden_size, initializers=None, partitioners=None,
               regularizers=None, custom_getter=None, name="gru"):
    """Construct GRU.

    Args:
      hidden_size: (int) Hidden size dimensionality.
      initializers: Dict containing ops to initialize the weights. This
        dict may contain any of the keys returned by
        `GRU.get_possible_initializer_keys`.
      partitioners: Optional dict containing partitioners to partition
        the weights and biases. As a default, no partitioners are used. This
        dict may contain any of the keys returned by
        `GRU.get_possible_initializer_keys`
      regularizers: Optional dict containing regularizers for the weights and
        biases. As a default, no regularizers are used. This
        dict may contain any of the keys returned by
        `GRU.get_possible_initializer_keys`
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. See the
        `tf.get_variable` documentation for more details.
      name: Name of the module.

    Raises:
      KeyError: if `initializers` contains any keys not returned by
        `GRU.get_possible_initializer_keys`.
      KeyError: if `partitioners` contains any keys not returned by
        `GRU.get_possible_initializer_keys`.
      KeyError: if `regularizers` contains any keys not returned by
        `GRU.get_possible_initializer_keys`.
    """
    super(GRU, self).__init__(custom_getter=custom_getter, name=name)
    self._hidden_size = hidden_size
    self._initializers = util.check_initializers(
        initializers, self.POSSIBLE_INITIALIZER_KEYS)
    self._partitioners = util.check_partitioners(
        partitioners, self.POSSIBLE_INITIALIZER_KEYS)
    self._regularizers = util.check_regularizers(
        regularizers, self.POSSIBLE_INITIALIZER_KEYS)

  @classmethod
  def get_possible_initializer_keys(cls):
    """Returns the keys the dictionary of variable initializers may contain.

    The set of all possible initializer keys are:
      wz: weight for input -> update cell
      uz: weight for prev_state -> update cell
      bz: bias for update_cell
      wr: weight for input -> reset cell
      ur: weight for prev_state -> reset cell
      br: bias for reset cell
      wh: weight for input -> candidate activation
      uh: weight for prev_state -> candidate activation
      bh: bias for candidate activation

    Returns:
      Set with strings corresponding to the strings that may be passed to the
        constructor.
    """
    return super(GRU, cls).get_possible_initializer_keys(cls)

  def _build(self, inputs, prev_state):
    """Connects the GRU module into the graph.

    If this is not the first time the module has been connected to the graph,
    the Tensors provided as inputs and state must have the same final
    dimension, in order for the existing variables to be the correct size for
    their corresponding multiplications. The batch size may differ for each
    connection.

    Args:
      inputs: Tensor of size `[batch_size, input_size]`.
      prev_state: Tensor of size `[batch_size, hidden_size]`.

    Returns:
      A tuple (output, next_state) where `output` is a Tensor of size
      `[batch_size, hidden_size]` and `next_state` is a Tensor of size
      `[batch_size, hidden_size]`.

    Raises:
      ValueError: If connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match previous
        invocations.
    """
    input_size = inputs.get_shape()[1]
    weight_shape = (input_size, self._hidden_size)
    u_shape = (self._hidden_size, self._hidden_size)
    bias_shape = (self._hidden_size,)

    self._wz = tf.get_variable(GRU.WZ, weight_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.WZ),
                               partitioner=self._partitioners.get(GRU.WZ),
                               regularizer=self._regularizers.get(GRU.WZ))
    self._uz = tf.get_variable(GRU.UZ, u_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.UZ),
                               partitioner=self._partitioners.get(GRU.UZ),
                               regularizer=self._regularizers.get(GRU.UZ))
    self._bz = tf.get_variable(GRU.BZ, bias_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.BZ),
                               partitioner=self._partitioners.get(GRU.BZ),
                               regularizer=self._regularizers.get(GRU.BZ))
    z = tf.sigmoid(tf.matmul(inputs, self._wz) +
                   tf.matmul(prev_state, self._uz) + self._bz)

    self._wr = tf.get_variable(GRU.WR, weight_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.WR),
                               partitioner=self._partitioners.get(GRU.WR),
                               regularizer=self._regularizers.get(GRU.WR))
    self._ur = tf.get_variable(GRU.UR, u_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.UR),
                               partitioner=self._partitioners.get(GRU.UR),
                               regularizer=self._regularizers.get(GRU.UR))
    self._br = tf.get_variable(GRU.BR, bias_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.BR),
                               partitioner=self._partitioners.get(GRU.BR),
                               regularizer=self._regularizers.get(GRU.BR))
    r = tf.sigmoid(tf.matmul(inputs, self._wr) +
                   tf.matmul(prev_state, self._ur) + self._br)

    self._wh = tf.get_variable(GRU.WH, weight_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.WH),
                               partitioner=self._partitioners.get(GRU.WH),
                               regularizer=self._regularizers.get(GRU.WH))
    self._uh = tf.get_variable(GRU.UH, u_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.UH),
                               partitioner=self._partitioners.get(GRU.UH),
                               regularizer=self._regularizers.get(GRU.UH))
    self._bh = tf.get_variable(GRU.BH, bias_shape, dtype=inputs.dtype,
                               initializer=self._initializers.get(GRU.BH),
                               partitioner=self._partitioners.get(GRU.BH),
                               regularizer=self._regularizers.get(GRU.BH))
    h_twiddle = tf.tanh(tf.matmul(inputs, self._wh) +
                        tf.matmul(r * prev_state, self._uh) + self._bh)

    state = (1 - z) * prev_state + z * h_twiddle
    return state, state

  @property
  def state_size(self):
    return tf.TensorShape([self._hidden_size])

  @property
  def output_size(self):
    return tf.TensorShape([self._hidden_size])


class HighwayCore(rnn_core.RNNCore):
  """Recurrent Highway Network cell.

  The implementation is based on: https://arxiv.org/pdf/1607.03474v5.pdf
  As per the first lines of section 5 of the reference paper, 1 - T is
  used instead of a dedicated C gate.

  Attributes:
    state_size: Integer indicating the size of state tensor.
    output_size: Integer indicating the size of the core output.
  """

  # Keys that may be provided for parameter initializers.
  WT = "wt"  # weight for input or previous state -> T gate
  BT = "bt"  # bias for previous state -> T gate
  WH = "wh"  # weight for input or previous state -> H gate
  BH = "bh"  # bias for previous state -> H gate

  def __init__(
      self,
      hidden_size,
      num_layers,
      initializers=None,
      partitioners=None,
      regularizers=None,
      custom_getter=None,
      name="highwaycore"):
    """Construct a new Recurrent Highway core.

    Args:
      hidden_size: (int) Hidden size dimensionality.
      num_layers: (int) Number of highway layers.
      initializers: Dict containing ops to initialize the weights. This
        dict may contain any of the keys returned by
        `HighwayCore.get_possible_initializer_keys`.
      partitioners: Optional dict containing partitioners to partition
        the weights and biases. As a default, no partitioners are used. This
        dict may contain any of the keys returned by
        `HighwayCore.get_possible_initializer_keys`.
      regularizers: Optional dict containing regularizers for the weights and
        biases. As a default, no regularizers are used. This
        dict may contain any of the keys returned by
        `HighwayCore.get_possible_initializer_keys`.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. See the
        `tf.get_variable` documentation for more details.
      name: Name of the module.

    Raises:
      KeyError: if `initializers` contains any keys not returned by
        `HighwayCore.get_possible_initializer_keys`.
      KeyError: if `partitioners` contains any keys not returned by
        `HighwayCore.get_possible_initializer_keys`.
      KeyError: if `regularizers` contains any keys not returned by
        `HighwayCore.get_possible_initializer_keys`.
    """
    super(HighwayCore, self).__init__(custom_getter=custom_getter, name=name)
    self._hidden_size = hidden_size
    self._num_layers = num_layers
    self._initializers = util.check_initializers(
        initializers, self.get_possible_initializer_keys(num_layers))
    self._partitioners = util.check_partitioners(
        partitioners, self.get_possible_initializer_keys(num_layers))
    self._regularizers = util.check_regularizers(
        regularizers, self.get_possible_initializer_keys(num_layers))

  @classmethod
  def get_possible_initializer_keys(cls, num_layers):
    """Returns the keys the dictionary of variable initializers may contain.

    The set of all possible initializer keys are:
      wt: weight for input -> T gate
      wh: weight for input -> H gate
      wtL: weight for prev state -> T gate for layer L (indexed from 0)
      whL: weight for prev state -> H gate for layer L (indexed from 0)
      btL: bias for prev state -> T gate for layer L (indexed from 0)
      bhL: bias for prev state -> H gate for layer L (indexed from 0)

    Args:
      num_layers: (int) Number of highway layers.
    Returns:
      Set with strings corresponding to the strings that may be passed to the
        constructor.
    """
    keys = [cls.WT, cls.WH]
    for layer_index in xrange(num_layers):
      layer_str = str(layer_index)
      keys += [
          cls.WT + layer_str,
          cls.BT + layer_str,
          cls.WH + layer_str,
          cls.BH + layer_str]
    return set(keys)

  def _build(self, inputs, prev_state):
    """Connects the highway core module into the graph.

    Args:
      inputs: Tensor of size `[batch_size, input_size]`.
      prev_state: Tensor of size `[batch_size, hidden_size]`.

    Returns:
      A tuple (output, next_state) where `output` is a Tensor of size
      `[batch_size, hidden_size]` and `next_state` is a Tensor of size
      `[batch_size, hidden_size]`.

    Raises:
      ValueError: If connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match previous
        invocations.
    """
    input_size = inputs.get_shape()[1]
    weight_shape = (input_size, self._hidden_size)
    u_shape = (self._hidden_size, self._hidden_size)
    bias_shape = (self._hidden_size,)

    def _get_variable(name, shape):
      return tf.get_variable(
          name,
          shape,
          dtype=inputs.dtype,
          initializer=self._initializers.get(name),
          partitioner=self._partitioners.get(name),
          regularizer=self._regularizers.get(name))

    pre_highway_wt = _get_variable(self.WT, weight_shape)
    pre_highway_wh = _get_variable(self.WH, weight_shape)
    state = prev_state
    for layer_index in xrange(self._num_layers):
      layer_str = str(layer_index)
      layer_wt = _get_variable(self.WT + layer_str, u_shape)
      layer_bt = _get_variable(self.BT + layer_str, bias_shape)
      layer_wh = _get_variable(self.WH + layer_str, u_shape)
      layer_bh = _get_variable(self.BH + layer_str, bias_shape)
      linear_t = tf.matmul(state, layer_wt) + layer_bt
      linear_h = tf.matmul(state, layer_wh) + layer_bh
      if layer_index == 0:
        linear_t += tf.matmul(inputs, pre_highway_wt)
        linear_h += tf.matmul(inputs, pre_highway_wh)
      output_t = tf.sigmoid(linear_t)
      output_h = tf.tanh(linear_h)
      state = state * (1 - output_t) + output_h * output_t

    return state, state

  @property
  def state_size(self):
    return tf.TensorShape([self._hidden_size])

  @property
  def output_size(self):
    return tf.TensorShape([self._hidden_size])


def highway_core_with_recurrent_dropout(
    hidden_size,
    num_layers,
    keep_prob=0.5,
    **kwargs):
  """Highway core with recurrent dropout.

  Args:
    hidden_size: (int) Hidden size dimensionality.
    num_layers: (int) Number of highway layers.
    keep_prob: the probability to keep an entry when applying dropout.
    **kwargs: Extra keyword arguments to pass to the highway core.

  Returns:
    A tuple (train_core, test_core) where train_core is a higway core with
    recurrent dropout enabled to be used for training and test_core is the
    same highway core without recurrent dropout.
  """

  core = HighwayCore(hidden_size, num_layers, **kwargs)
  return RecurrentDropoutWrapper(core, keep_prob), core
