# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implement Seq2Seq Transformer model by TF official NLP library.

Model paper: https://arxiv.org/pdf/1706.03762.pdf
"""
import math

import tensorflow as tf
import tf_keras as keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search

# Default end-of-sequence token id; used as the default value of the `eos_id`
# argument of `Seq2SeqTransformer`.
EOS_ID = 1


class Seq2SeqTransformer(keras.Model):
  """Transformer model with Keras.

  Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf

  The Transformer model consists of an encoder and decoder. The input is an int
  sequence (or a batch of sequences). The encoder produces a continuous
  representation, and the decoder uses the encoder output to generate
  probabilities for the output sequence.

  On top of the reference architecture this variant keeps separate encoder and
  decoder embedding tables, can post-process externally supplied position
  encodings with `TreePositionEncodingPostProcessing`, can mix the mean
  embedding of per-position candidate actions into the encoder/decoder inputs,
  masks the output logits with `dec_action_mask`, and offers an incremental
  decoding path that reuses a decoder key/value cache.
  """

  def __init__(self,
               enc_vocab_size=33708,
               dec_vocab_size=33708,
               embedding_width=512,
               dropout_rate=0.0,
               padded_decode=False,
               decode_max_length=None,
               extra_decode_length=0,
               beam_size=4,
               alpha=0.6,
               encoder_layer=None,
               decoder_layer=None,
               eos_id=EOS_ID,
               max_tree_depth=None,
               add_action_mask_to_inputs=False,
               **kwargs):
    """Initialize layers to build Transformer model.

    Args:
      enc_vocab_size: Size of the encoder (source) vocabulary.
      dec_vocab_size: Size of the decoder (target) vocabulary.
      embedding_width: Size of hidden layer for embedding.
      dropout_rate: Dropout probability.
      padded_decode: Whether to max_sequence_length padding is used. If set
        False, max_sequence_length padding is not used.
      decode_max_length: maximum number of steps to decode a sequence.
      extra_decode_length: Beam search will run extra steps to decode.
      beam_size: Number of beams for beam search
      alpha: The strength of length normalization for beam search.
      encoder_layer: An initialized encoder layer.
      decoder_layer: An initialized decoder layer.
      eos_id: Id of end of sentence token.
      max_tree_depth: Optional maximum tree depth. When set, position encodings
        supplied through `enc_pos_encoding` / `dec_pos_encoding` are passed
        through a `TreePositionEncodingPostProcessing` layer. When `None`, no
        such layer is created.
      add_action_mask_to_inputs: Whether the mean embedding of the candidate
        actions (`enc_action_mask` / `dec_action_mask`) is added to the
        encoder / decoder inputs.
      **kwargs: other keyword arguments.

    Raises:
      ValueError: If `max_tree_depth` is set but is not positive, or
        `embedding_width` is not a multiple of `2 * max_tree_depth`.
    """
    super().__init__(**kwargs)
    self._enc_vocab_size = enc_vocab_size
    self._dec_vocab_size = dec_vocab_size
    self._embedding_width = embedding_width
    self._dropout_rate = dropout_rate
    self._padded_decode = padded_decode
    self._decode_max_length = decode_max_length
    self._extra_decode_length = extra_decode_length
    self._beam_size = beam_size
    self._alpha = alpha
    self._eos_id = eos_id
    self._max_tree_depth = max_tree_depth
    self._add_action_mask_to_inputs = add_action_mask_to_inputs
    # When True, `call` additionally returns the encoder outputs, the decoder
    # outputs and the decoder key/value cache.
    self.return_cache = False
    self.enc_embedding_lookup = layers.OnDeviceEmbedding(
        vocab_size=self._enc_vocab_size,
        embedding_width=self._embedding_width,
        initializer=tf.random_normal_initializer(
            mean=0., stddev=self._embedding_width**-0.5),
        scale_factor=self._embedding_width**0.5)
    self.dec_embedding_lookup = layers.OnDeviceEmbedding(
        vocab_size=self._dec_vocab_size,
        embedding_width=self._embedding_width,
        initializer=tf.random_normal_initializer(
            mean=0., stddev=self._embedding_width**-0.5),
        scale_factor=self._embedding_width**0.5)
    self.encoder_layer = encoder_layer
    self.decoder_layer = decoder_layer
    self.position_embedding = layers.RelativePositionEmbedding(
        hidden_size=self._embedding_width)
    if max_tree_depth is None:
      # The default configuration has no tree encoding; building the layer
      # would require dividing by `max_tree_depth`.
      self.tree_position_embedding = None
    else:
      if max_tree_depth <= 0 or embedding_width % (max_tree_depth * 2) != 0:
        # The post-processed encoding has width
        # `2 * max_tree_depth * num_repeat` and is added to the embeddings, so
        # it must tile `embedding_width` exactly.
        raise ValueError(
            "`embedding_width` (%r) must be a positive multiple of "
            "`2 * max_tree_depth` (max_tree_depth=%r)." %
            (embedding_width, max_tree_depth))
      self.tree_position_embedding = TreePositionEncodingPostProcessing(
          max_tree_depth=max_tree_depth,
          num_repeat=embedding_width // (max_tree_depth * 2))
    self.encoder_dropout = keras.layers.Dropout(rate=self._dropout_rate)
    self.decoder_dropout = keras.layers.Dropout(rate=self._dropout_rate)
    # Learnable scalars weighting the candidate-action embeddings that are
    # mixed into the encoder / decoder inputs.
    self.enc_candidate_ratio = self.add_weight(
        name='enc_candidate_ratio',
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.1))
    self.dec_candidate_ratio = self.add_weight(
        name='dec_candidate_ratio',
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.1))

  def get_config(self):
    """Returns the constructor arguments, keyed by `__init__` parameter name."""
    config = {
        "enc_vocab_size": self._enc_vocab_size,
        "dec_vocab_size": self._dec_vocab_size,
        "embedding_width": self._embedding_width,
        "dropout_rate": self._dropout_rate,
        "padded_decode": self._padded_decode,
        "decode_max_length": self._decode_max_length,
        "eos_id": self._eos_id,
        "extra_decode_length": self._extra_decode_length,
        "beam_size": self._beam_size,
        "alpha": self._alpha,
        "encoder_layer": self.encoder_layer,
        "decoder_layer": self.decoder_layer,
        "max_tree_depth": self._max_tree_depth,
        "add_action_mask_to_inputs": self._add_action_mask_to_inputs,
    }
    base_config = super(Seq2SeqTransformer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _embedding_linear(self, embedding_matrix, x):
    """Uses embeddings as linear transformation weights.

    Args:
      embedding_matrix: Tensor indexed as `[vocab_size, hidden_size]`.
      x: Rank-3 tensor indexed as `[batch_size, length, hidden_size]`.

    Returns:
      Logits reshaped to `[batch_size, length, vocab_size]`, in the compute
      dtype of this model.
    """
    embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
    x = tf.cast(x, dtype=self.compute_dtype)
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    hidden_size = tf.shape(x)[2]
    vocab_size = tf.shape(embedding_matrix)[0]

    x = tf.reshape(x, [-1, hidden_size])
    logits = tf.matmul(x, embedding_matrix, transpose_b=True)

    return tf.reshape(logits, [batch_size, length, vocab_size])

  def _candidate_action_embeddings(self, action_mask, embedding_lookup):
    """Averages the embeddings selected by `action_mask` over the last axis.

    Args:
      action_mask: Tensor whose last axis indexes the vocabulary of
        `embedding_lookup`; it is cast to the compute dtype and summed over
        axis 2, so it must be rank 3.
      embedding_lookup: Embedding layer exposing an `embeddings` matrix.

    Returns:
      `action_mask @ embeddings` divided by the number of selected entries.
    """
    action_mask = tf.cast(action_mask, dtype=self.compute_dtype)
    summed = tf.matmul(action_mask, embedding_lookup.embeddings)
    # NOTE(review): a position whose mask row is all zeros divides 0 by 0 here
    # -- confirm callers always mark at least one candidate per position.
    return summed / tf.reduce_sum(action_mask, axis=2, keepdims=True)

  def _empty_decoder_cache(self, batch_size, length):
    """Builds a zero-filled per-layer key/value cache for the decoder.

    Args:
      batch_size: Leading dimension of every cache tensor.
      length: Number of already cached time steps (0 for an empty cache).

    Returns:
      Dict mapping `str(layer_index)` to `{"key": ..., "value": ...}` tensors
      of shape `[batch_size, length, num_heads, dim_per_head]`.
    """
    num_heads = self.decoder_layer.num_attention_heads
    dim_per_head = self._embedding_width // num_heads
    shape = [batch_size, length, num_heads, dim_per_head]
    # Cache dtype needs to match beam_search dtype.
    return {
        str(layer): {
            "key": tf.zeros(shape, dtype=self.compute_dtype),
            "value": tf.zeros(shape, dtype=self.compute_dtype),
        } for layer in range(self.decoder_layer.num_layers)
    }

  def _parse_inputs(self, inputs):
    """Parses the `call` inputs and returns an uniformed output.

    Args:
      inputs: Dictionary holding either `inputs` (token ids) or
        `embedded_inputs` together with `input_masks`. When both `inputs` and
        `embedded_inputs` are present, `inputs` takes precedence.

    Returns:
      Tuple `(embedded_inputs, boolean_mask, input_shape, source_dtype)`.

    Raises:
      KeyError: If neither `inputs` nor `embedded_inputs` is given, or if
        `embedded_inputs` is given without `input_masks`.
    """
    sources = inputs.get("inputs", None)
    input_mask = inputs.get("input_masks", None)
    embedded = inputs.get("embedded_inputs", None)
    missing_features_message = (
        "The call method expects either `inputs` or `embedded_inputs` and "
        "`input_masks` as input features.")

    if sources is None and embedded is not None:
      if input_mask is None:
        # Without token ids the padding mask cannot be derived, and every code
        # path in `call` needs it.
        raise KeyError(missing_features_message)
      embedded_inputs = embedded
      boolean_mask = input_mask
      input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
      source_dtype = embedded.dtype
    elif sources is not None:
      embedded_inputs = self.enc_embedding_lookup(sources)
      # Token id 0 is treated as padding.
      boolean_mask = tf.not_equal(sources, 0)
      input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
      source_dtype = sources.dtype
    else:
      raise KeyError(missing_features_message)

    return embedded_inputs, boolean_mask, input_shape, source_dtype

  def _beam_search_decode(self, encoder_outputs, boolean_mask, input_shape):
    """Runs beam search over the decoder and returns the top beam.

    Args:
      encoder_outputs: Encoder output tensor.
      boolean_mask: Boolean padding mask of the source sequence.
      input_shape: `[batch_size, input_length, ...]` shape list of the source.

    Returns:
      Dict with `outputs` (`decoded_ids[:, 0, 1:]`, i.e. the best beam without
      the initial id) and `scores` (`scores[:, 0]`).
    """
    if self._padded_decode:
      max_decode_length = self._decode_max_length
    else:
      max_decode_length = self._decode_max_length or (
          tf.shape(encoder_outputs)[1] + self._extra_decode_length)
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

    batch_size = tf.shape(encoder_outputs)[0]
    # Create initial set of IDs that will be passed to symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create cache storing decoder attention values for each layer.
    init_decode_length = max_decode_length if self._padded_decode else 0
    cache = self._empty_decoder_cache(batch_size, init_decode_length)

    # Add encoder output and attention bias to the cache.
    cache["encoder_outputs"] = tf.cast(
        encoder_outputs, dtype=self.compute_dtype)
    cache["encoder_decoder_attention_mask"] = tf.cast(
        tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
        dtype=self.compute_dtype)

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self._dec_vocab_size,
        beam_size=self._beam_size,
        alpha=self._alpha,
        max_decode_length=max_decode_length,
        eos_id=self._eos_id,
        padded_decode=self._padded_decode,
        dtype=self.compute_dtype)

    # Get the top sequence for each batch element
    return {"outputs": decoded_ids[:, 0, 1:], "scores": scores[:, 0]}

  def call(self, inputs, training=None, mask=None, return_kv_cache=False,
           return_last_token=False):
    """Calculate target logits or inferred target sequences.

    Args:
      inputs: a dictionary of tensors.
        Feature `inputs` (optional): int tensor with shape
          `[batch_size, input_length]`.
        Feature `embedded_inputs` (optional): float tensor with shape
          `[batch_size, input_length, embedding_width]`.
        Feature `targets` (optional): None or int tensor with shape
          `[batch_size, target_length]`.
        Feature `input_masks` (optional): When providing the `embedded_inputs`,
          the dictionary must provide a boolean mask marking the filled time
          steps. The shape of the tensor is `[batch_size, input_length]`.
        Feature `cache` (optional): dictionary with `encoder_outputs` and
          `kv_cache` from a previous call; when given (and non-empty) the
          encoder is skipped.
        Feature `enc_pos_encoding` / `dec_pos_encoding` (optional): position
          encodings used instead of the sinusoidal ones. They are passed
          through the tree post-processing layer when `max_tree_depth` is set.
        Feature `enc_action_mask` / `dec_action_mask` (optional): candidate
          action masks over the encoder / decoder vocabulary.
          `dec_action_mask`, when given, is also used as the `tf.where`
          condition that masks the output logits.
        Either `inputs` or `embedded_inputs` and `input_masks` must be present
        in the input dictionary. In the second case the projection of the
        integer tokens to the transformer embedding space is skipped and
        `input_masks` is expected to be present.
      training: Accepted for Keras API compatibility; not read by this method.
      mask: Accepted for Keras API compatibility; not read by this method.
      return_kv_cache: If True, run the incremental decoding path: `targets`
        are used as given (no right shift) and the decoder attends to the
        cached keys/values in addition to the new positions.
      return_last_token: If True, only the logits of the last decoder position
        are kept.

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence, which is a float tensor with shape
      `(batch_size, target_length, vocab_size)` (or `(batch_size, vocab_size)`
      with `return_last_token`). When `self.return_cache` is True the return
      value is the tuple `(logits, cache)` where `cache` holds
      `encoder_outputs`, `outputs` and `kv_cache`. If target is `None`, then
      generate output sequence one token at a time and
      returns a dictionary {
          outputs: the top decoded sequence per batch element
          scores: the score of that sequence}
      Even when `float16` is used, the logits are always `float32`.

    Raises:
      KeyError: If the required input features are missing (see
        `_parse_inputs`).
    """
    targets = inputs.get("targets", None)
    cache = inputs.get("cache", None)
    (embedded_inputs, boolean_mask, input_shape,
     source_dtype) = self._parse_inputs(inputs)
    enc_action_mask = inputs.get("enc_action_mask")
    dec_action_mask = inputs.get("dec_action_mask")

    if not cache:
      # Prepare inputs to the layer stack by adding positional encodings and
      # applying dropout.
      embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
      embedded_inputs *= tf.expand_dims(embedding_mask, -1)
      # Attention_mask generation: simply mask the [PAD] positions.
      attention_mask = tf.cast(
          tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
          dtype=source_dtype)
      broadcast_ones = tf.ones(
          shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
      attention_mask = broadcast_ones * attention_mask

      enc_pos_encoding = inputs.get("enc_pos_encoding")
      if enc_pos_encoding is not None:
        # Only externally supplied encodings go through the tree
        # post-processing, mirroring the decoder side; the sinusoidal fallback
        # is rank 2 and cannot be tiled by that layer.
        if self._max_tree_depth is not None:
          enc_pos_encoding = self.tree_position_embedding(enc_pos_encoding)
      else:
        enc_pos_encoding = self.position_embedding(embedded_inputs)
      enc_pos_encoding = tf.cast(enc_pos_encoding, embedded_inputs.dtype)
      encoder_inputs = embedded_inputs + enc_pos_encoding

      if enc_action_mask is not None and self._add_action_mask_to_inputs:
        encoder_inputs += (
            self.enc_candidate_ratio * self._candidate_action_embeddings(
                enc_action_mask, self.enc_embedding_lookup))

      encoder_inputs = self.encoder_dropout(encoder_inputs)

      encoder_outputs = self.encoder_layer(
          encoder_inputs, attention_mask=attention_mask)
      kv_cache = None
    else:
      encoder_outputs = cache['encoder_outputs']
      kv_cache = cache['kv_cache']

    if targets is None:
      return self._beam_search_decode(encoder_outputs, boolean_mask,
                                      input_shape)

    if not return_kv_cache:
      # Training mode: shift targets to the right and drop the last element,
      # which puts id 0 in front as the start token.
      targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]

    decoder_inputs = self.dec_embedding_lookup(targets)
    decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
    batch_size = decoder_shape[0]
    decoder_length = decoder_shape[1]

    if return_kv_cache:
      if kv_cache is None:
        kv_cache = self._empty_decoder_cache(batch_size, 0)
      kv_length = tf.shape(kv_cache['0']['key'])[1]
      length = decoder_length + kv_length
    else:
      kv_length = None
      length = tf.shape(decoder_inputs)[1]

    if dec_action_mask is not None and self._add_action_mask_to_inputs:
      decoder_inputs += (
          self.dec_candidate_ratio * self._candidate_action_embeddings(
              dec_action_mask, self.dec_embedding_lookup))

    dec_pos_encoding = inputs.get("dec_pos_encoding")
    if dec_pos_encoding is not None:
      if self._max_tree_depth is not None:
        dec_pos_encoding = self.tree_position_embedding(dec_pos_encoding)
      if not return_kv_cache:
        # Shift the supplied encodings right together with the targets.
        dec_pos_encoding = tf.pad(
            dec_pos_encoding, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    elif return_kv_cache:
      # The new tokens follow `kv_length` cached positions, so their
      # sinusoidal encodings start at that offset rather than at 0.
      dec_pos_encoding = self.position_embedding(
          inputs=None, length=length)[kv_length:]
    else:
      dec_pos_encoding = self.position_embedding(decoder_inputs)
    dec_pos_encoding = tf.cast(dec_pos_encoding, embedded_inputs.dtype)
    decoder_inputs += dec_pos_encoding

    decoder_inputs = self.decoder_dropout(decoder_inputs)

    # Causal mask: lower triangular part of a [length, length] matrix.
    self_attention_mask = tf.linalg.band_part(
        tf.ones([length, length]), -1, 0)
    self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
    if return_kv_cache:
      # Keep only the rows of the new positions; the columns still span the
      # cached positions plus the new ones.
      self_attention_mask = self_attention_mask[:, kv_length:length, :]
    self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

    attention_mask = tf.cast(
        tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
    attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

    outputs = self.decoder_layer(
        decoder_inputs,
        encoder_outputs,
        cache=kv_cache,
        self_attention_mask=self_attention_mask,
        cross_attention_mask=attention_mask)
    logits = self._embedding_linear(self.dec_embedding_lookup.embeddings,
                                    outputs)
    if return_last_token:
      logits = logits[:, -1, :]
    # Model outputs should be float32 to avoid numeric issues.
    # https://www.tensorflow.org/guide/mixed_precision#building_the_model
    logits = tf.cast(logits, tf.float32)
    if dec_action_mask is not None:
      if return_kv_cache:
        # [batch_size, 1, vocab_size] -> [batch_size, vocab_size].
        # NOTE(review): this only lines up with the logits when
        # `return_last_token` is also set -- confirm against callers.
        dec_action_mask = dec_action_mask[:, 0, :]
      # Disallowed actions get the most negative float32 as their logit.
      logits = tf.where(dec_action_mask, logits, tf.float32.min)
    if self.return_cache:
      return logits, {
          'encoder_outputs': encoder_outputs,
          'outputs': outputs,
          'kv_cache': kv_cache
      }
    return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens.

    Args:
      max_decode_length: Maximum number of decoding steps; sizes the timing
        signal and the causal self-attention mask.

    Returns:
      A function `(ids, i, cache) -> (logits, cache)` for beam search.
    """
    timing_signal = self.position_embedding(
        inputs=None, length=max_decode_length + 1)
    timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
    decoder_self_attention_mask = tf.linalg.band_part(
        tf.ones([max_decode_length, max_decode_length],
                dtype=self.compute_dtype), -1, 0)
    decoder_self_attention_mask = tf.reshape(
        decoder_self_attention_mask, [1, max_decode_length, max_decode_length])

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape `(batch_size *
          beam_size, i + 1)`.
        i: Loop index.
        cache: Dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape `(batch_size * beam_size, vocab_size)`,
           updated cache values)
      """
      # Set decoder input to the last generated IDs
      decoder_input = ids[:, -1:]

      # Preprocess decoder input by getting embeddings and adding timing signal.
      decoder_input = self.dec_embedding_lookup(decoder_input)
      decoder_input += timing_signal[i]
      if self._padded_decode:
        # indexing does not work on TPU.
        bias_shape = decoder_self_attention_mask.shape.as_list()
        self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
                                       [bias_shape[0], 1, bias_shape[2]])
      else:
        self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
      decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
      batch_size = decoder_shape[0]
      decoder_length = decoder_shape[1]

      self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
      attention_mask = cache.get("encoder_decoder_attention_mask")
      attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

      decoder_outputs = self.decoder_layer(
          decoder_input,
          cache.get("encoder_outputs"),
          self_attention_mask=self_attention_mask,
          cross_attention_mask=attention_mask,
          cache=cache,
          decode_loop_step=i if self._padded_decode else None)

      decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
      logits = self._embedding_linear(self.dec_embedding_lookup.embeddings,
                                      decoder_outputs)
      logits = tf.squeeze(logits, axis=[1])
      return logits, cache

    return symbols_to_logits_fn


class TreePositionEncodingPostProcessing(keras.layers.Layer):
    """Tiles a tree position encoding and scales it with learned decay weights.

    The layer owns one trainable value per repeat (`p`, initialised to 2 and
    squashed with `tanh`). For repeat `r` and depth `i` the weight is
    `tanh(p_r) ** i * sqrt(1 - tanh(p_r) ** 2)`; every weight is applied to two
    adjacent channels. The input is tiled `num_repeat` times along its last
    axis, multiplied by these weights and by
    `sqrt(max_tree_depth * num_repeat)`.
    """

    def __init__(self, max_tree_depth, num_repeat, **kwargs):
        """Creates the layer.

        Args:
          max_tree_depth: Number of depth levels; the input's last axis is
            expected to have `max_tree_depth * 2` entries.
          num_repeat: How many times the input is tiled along its last axis;
            also the number of learned decay parameters.
          **kwargs: Keyword arguments forwarded to `keras.layers.Layer`.
        """
        super().__init__(**kwargs)
        self._max_tree_depth = max_tree_depth
        self._num_repeat = num_repeat
        self.p = self.add_weight(name="p", shape=(num_repeat, ), dtype=tf.float32,
                                 initializer=tf.constant_initializer(2))

    def get_config(self):
        """Returns the constructor arguments so the layer can be re-created."""
        config = {
            "max_tree_depth": self._max_tree_depth,
            "num_repeat": self._num_repeat,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs, *args, **kwargs):
        """Applies the learned per-depth weights.

        Args:
          inputs: Tensor of shape `[batch_size, max_length, max_tree_depth * 2]`.
          *args: Unused.
          **kwargs: Unused.

        Returns:
          Tensor of shape
          `[batch_size, max_length, max_tree_depth * 2 * num_repeat]`.
        """
        p = tf.tanh(self.p)  # scale p from -1 to 1
        # Powers p**0 .. p**(max_tree_depth - 1), one row per repeat.
        p_vector = tf.stack(
            [tf.math.pow(p, i) for i in range(self._max_tree_depth)],
            axis=1)  # [num_repeat, max_tree_depth]
        p_vector = p_vector * tf.expand_dims(
            tf.math.sqrt(1 - tf.math.square(p)), axis=1)
        p_vector = tf.reshape(
            p_vector, shape=(self._max_tree_depth * self._num_repeat, ))
        # Each weight covers two adjacent channels of the tiled input.
        p_vector = tf.reshape(
            tf.repeat(p_vector, repeats=2),
            shape=(1, 1, self._max_tree_depth * 2 * self._num_repeat))
        scale = math.sqrt(self._max_tree_depth * self._num_repeat)
        inputs = tf.tile(inputs, [1, 1, self._num_repeat]) * p_vector * scale
        return inputs


class TransformerEncoder(keras.layers.Layer):
  """Stack of Transformer encoder blocks followed by a layer normalization.

  Every block consists of:
    1. A self-attention sublayer
    2. A feedforward sublayer (2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Stores the hyper-parameters; sublayers are created in `build`.

    Args:
      num_layers: How many encoder blocks to stack.
      num_attention_heads: Attention heads per block.
      intermediate_size: Width of the feedforward (intermediate) layer.
      activation: Activation of the feedforward layer.
      dropout_rate: Output dropout probability of each block.
      attention_dropout_rate: Dropout probability inside attention.
      use_bias: Whether attention layers use a bias term.
      norm_first: If True, inputs of the attention and feedforward sublayers
        are normalized; otherwise their outputs are.
      norm_epsilon: Epsilon used by the normalization layers.
      intermediate_dropout: Dropout probability of the intermediate layer.
      **kwargs: Keyword arguments passed on to `keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Creates the encoder blocks and the final normalization layer."""
    self.encoder_layers = [
        layers.TransformerEncoderBlock(
            num_attention_heads=self.num_attention_heads,
            inner_dim=self._intermediate_size,
            inner_activation=self._activation,
            output_dropout=self._dropout_rate,
            attention_dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            inner_dropout=self._intermediate_dropout,
            # input_shape[2] is the hidden size of the encoder inputs.
            attention_initializer=attention_initializer(input_shape[2]),
            name=("layer_%d" % block_index))
        for block_index in range(self.num_layers)
    ]
    self.output_normalization = keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the constructor arguments."""
    own_config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    return {**super().get_config(), **own_config}

  def call(self, encoder_inputs, attention_mask=None):
    """Runs the inputs through every encoder block, then normalizes.

    Args:
      encoder_inputs: A tensor with shape `(batch_size, input_length,
        hidden_size)`.
      attention_mask: A mask for the encoder self-attention layer with shape
        `(batch_size, input_length, input_length)`.

    Returns:
      Output of encoder which is a `float32` tensor with shape
        `(batch_size, input_length, hidden_size)`.
    """
    hidden_states = encoder_inputs
    for encoder_block in self.encoder_layers:
      hidden_states = encoder_block([hidden_states, attention_mask])
    return self.output_normalization(hidden_states)


class TransformerDecoder(keras.layers.Layer):
  """Transformer decoder.

  Like the encoder, the decoder is made up of N identical layers.
  Each layer is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Initialize a Transformer decoder.

    Args:
      num_layers: Number of layers.
      num_attention_heads: Number of attention heads.
      intermediate_size: Size of the intermediate (Feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability.
      attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable use_bias in attention layer. If set `False`,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set `False`, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers, including
        the final output normalization.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: key word arguments passed to keras.layers.Layer.
    """
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Implements build() for the layer."""
    self.decoder_layers = []
    for i in range(self.num_layers):
      self.decoder_layers.append(
          layers.TransformerDecoderBlock(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self._intermediate_size,
              intermediate_activation=self._activation,
              dropout_rate=self._dropout_rate,
              attention_dropout_rate=self._attention_dropout_rate,
              use_bias=self._use_bias,
              norm_first=self._norm_first,
              norm_epsilon=self._norm_epsilon,
              intermediate_dropout=self._intermediate_dropout,
              attention_initializer=attention_initializer(input_shape[2]),
              name=("layer_%d" % i)))
    # Honour the configured epsilon (default 1e-6) instead of a hard-coded
    # value, consistent with TransformerEncoder.
    self.output_normalization = keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super(TransformerDecoder, self).build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the constructor arguments."""
    config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    base_config = super(TransformerDecoder, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           target,
           memory,
           self_attention_mask=None,
           cross_attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_all_decoder_outputs=False):
    """Return the output of the decoder layer stacks.

    Args:
      target: A tensor with shape `(batch_size, target_length, hidden_size)`.
      memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: A tensor with shape `(batch_size, target_len,
        target_length)`, the mask for decoder self-attention layer.
      cross_attention_mask: A tensor with shape `(batch_size, target_length,
        input_length)` which is the mask for encoder-decoder attention layer.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values, keyed by the layer index as a string
        (`"0"`, `"1"`, ...). Each entry is handed to the matching decoder
        block and replaced, in place, by the cache that block returns.
      decode_loop_step: An integer, the step number of the decoding loop. Used
        only for autoregressive inference on TPU.
      return_all_decoder_outputs: Return all decoder layer outputs.
        Note that the outputs are layer normed.
        This is useful when introducing per layer auxiliary loss.

    Returns:
      Output of decoder: a float32 tensor with shape
      `(batch_size, target_length, hidden_size)`, or, when
      `return_all_decoder_outputs` is True, a list with one such normalized
      tensor per decoder layer.
    """

    output_tensor = target
    decoder_outputs = []
    for layer_idx, decoder_block in enumerate(
        self.decoder_layers[:self.num_layers]):
      transformer_inputs = [
          output_tensor, memory, cross_attention_mask, self_attention_mask
      ]
      # Gets the cache for decoding.
      if cache is None:
        output_tensor, _ = decoder_block(transformer_inputs)
      else:
        cache_layer_idx = str(layer_idx)
        output_tensor, cache[cache_layer_idx] = decoder_block(
            transformer_inputs,
            cache=cache[cache_layer_idx],
            decode_loop_step=decode_loop_step)
      if return_all_decoder_outputs:
        decoder_outputs.append(self.output_normalization(output_tensor))

    if return_all_decoder_outputs:
      return decoder_outputs
    return self.output_normalization(output_tensor)


def attention_initializer(hidden_size):
  """Builds the uniform initializer used by the attention layers.

  The range is `[-limit, limit]` with `limit = sqrt(6 / (fan_in + fan_out))`,
  where both fans equal `hidden_size`.

  Args:
    hidden_size: Hidden dimension; converted with `int()`.

  Returns:
    A `keras.initializers.RandomUniform` instance.
  """
  fan_in = fan_out = int(hidden_size)
  bound = math.sqrt(6.0 / (fan_in + fan_out))
  return keras.initializers.RandomUniform(minval=-bound, maxval=bound)


def masked_loss(label, pred):
    """Sparse categorical cross-entropy averaged over non-padding positions.

    Args:
      label: Integer class ids; positions equal to 0 are treated as padding
        and excluded from the average.
      pred: Logits whose last axis indexes the classes.

    Returns:
      Scalar mean loss over the non-padding positions, or 0 when every
      position is padding (instead of the NaN a plain division would give).
    """
    loss_object = keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(tf.not_equal(label, 0), dtype=loss.dtype)
    loss *= mask

    # divide_no_nan returns 0 for an all-padding batch so a single empty batch
    # cannot poison training with NaN.
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
    """Fraction of non-padding positions whose arg-max prediction is correct.

    Args:
      label: Integer class ids; positions equal to 0 are treated as padding
        and excluded.
      pred: Rank-3 scores; the arg-max is taken over axis 2.

    Returns:
      Scalar float32 accuracy over the non-padding positions, or 0 when every
      position is padding (instead of the NaN a plain division would give).
    """
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)

    mask = tf.not_equal(label, 0)
    match = tf.logical_and(tf.equal(label, pred), mask)

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))


class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
  """Warm-up then inverse-square-root learning-rate schedule.

  `lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)`.
  """

  def __init__(self, d_model, warmup_steps=4000):
    """Creates the schedule.

    Args:
      d_model: Model width; the learning rate is scaled by `d_model**-0.5`.
      warmup_steps: Value used in the `step * warmup_steps**-1.5` term of the
        schedule.
    """
    super().__init__()

    # Keep the value as passed in for `get_config`; `self.d_model` holds the
    # float32 tensor used in the computation.
    self._d_model_config = d_model
    self.d_model = tf.cast(d_model, tf.float32)

    self.warmup_steps = warmup_steps

  def __call__(self, step):
    """Returns the learning rate for optimizer step `step`."""
    step = tf.cast(step, dtype=tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    lr = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    return lr

  def get_config(self):
    """Returns the constructor arguments so the schedule can be serialized."""
    return {
        "d_model": self._d_model_config,
        "warmup_steps": self.warmup_steps,
    }
