# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
IMPORTANT: This code is taken directly from Tensorflow
(https://github.com/tensorflow/tensorflow) and is copied temporarily
until it is available in a packaged Tensorflow version on pypi.

TODO(dennybritz): Delete this code when it becomes available in TF.

A library of helpers for use with SamplingDecoders.
"""

# pylint: skip-file

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

from seq2seq.contrib.seq2seq import decoder

__all__ = [
    "Helper",
    "TrainingHelper",
    "GreedyEmbeddingHelper",
    "CustomHelper",
    "ScheduledEmbeddingTrainingHelper",
    "ScheduledOutputTrainingHelper",
]

_transpose_batch_time = decoder._transpose_batch_time  # pylint: disable=protected-access


def _unstack_ta(inp):
  return tensor_array_ops.TensorArray(
      dtype=inp.dtype, size=array_ops.shape(inp)[0],
      element_shape=inp.get_shape()[1:]).unstack(inp)


@six.add_metaclass(abc.ABCMeta)
class Helper(object):
  """Helper interface.  Helper instances are used by SamplingDecoder."""

  @abc.abstractproperty
  def batch_size(self):
    """Returns a scalar int32 tensor."""
    raise NotImplementedError("batch_size has not been implemented")

  @abc.abstractmethod
  def initialize(self, name=None):
    """Returns `(initial_finished, initial_inputs)`."""
    pass

  @abc.abstractmethod
  def sample(self, time, outputs, state, name=None):
    """Returns `sample_ids`."""
    pass

  @abc.abstractmethod
  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Returns `(finished, next_inputs, next_state)`."""
    pass


class CustomHelper(Helper):
  """Base abstract class that allows the user to customize sampling."""

  def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
    """Initializer.

    Args:
      initialize_fn: callable that returns `(finished, next_inputs)`
        for the first iteration.
      sample_fn: callable that takes `(time, outputs, state)`
        and emits tensor `sample_ids`.
      next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)`
        and emits `(finished, next_inputs, next_state)`.
    """
    self._initialize_fn = initialize_fn
    self._sample_fn = sample_fn
    self._next_inputs_fn = next_inputs_fn
    self._batch_size = None

  @property
  def batch_size(self):
    if self._batch_size is None:
      raise ValueError("batch_size accessed before initialize was called")
    return self._batch_size

  def initialize(self, name=None):
    with ops.name_scope(name, "%sInitialize" % type(self).__name__):
      (finished, next_inputs) = self._initialize_fn()
      if self._batch_size is None:
        self._batch_size = array_ops.size(finished)
    return (finished, next_inputs)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(
        name, "%sSample" % type(self).__name__, (time, outputs, state)):
      return self._sample_fn(time=time, outputs=outputs, state=state)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(
        name, "%sNextInputs" % type(self).__name__, (time, outputs, state)):
      return self._next_inputs_fn(
          time=time, outputs=outputs, state=state, sample_ids=sample_ids)


class TrainingHelper(Helper):
  """A helper for use during training.  Only reads inputs.

  Returned sample_ids are the argmax of the RNN output logits.
  """

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
      inputs = ops.convert_to_tensor(inputs, name="inputs")
      if not time_major:
        inputs = nest.map_structure(_transpose_batch_time, inputs)

      self._input_tas = nest.map_structure(_unstack_ta, inputs)
      self._sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if self._sequence_length.get_shape().ndims != 1:
        raise ValueError(
            "Expected sequence_length to be a vector, but received shape: %s" %
            self._sequence_length.get_shape())

      self._zero_inputs = nest.map_structure(
          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

      self._batch_size = array_ops.size(sequence_length)

  @property
  def batch_size(self):
    return self._batch_size

  def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs)

  def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)
      return sample_ids

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)


class ScheduledEmbeddingTrainingHelper(TrainingHelper):
  """A training helper that adds scheduled sampling.

  Returns -1s for sample_ids where no sampling took place; valid sample id
  values elsewhere.
  """

  def __init__(self, inputs, sequence_length, embedding, sampling_probability,
               time_major=False, seed=None, scheduling_seed=None, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      sampling_probability: A 0D `float32` tensor: the probability of sampling
        categorically from the output ids instead of reading directly from the
        inputs.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      seed: The sampling seed.
      scheduling_seed: The schedule decision rule sampling seed.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sampling_probability` is not a scalar or vector.
    """
    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
                        [embedding, sampling_probability]):
      if callable(embedding):
        self._embedding_fn = embedding
      else:
        self._embedding_fn = (
            lambda ids: embedding_ops.embedding_lookup(embedding, ids))
      self._sampling_probability = ops.convert_to_tensor(
          sampling_probability, name="sampling_probability")
      if self._sampling_probability.get_shape().ndims not in (0, 1):
        raise ValueError(
            "sampling_probability must be either a scalar or a vector. "
            "saw shape: %s" % (self._sampling_probability.get_shape()))
      self._seed = seed
      self._scheduling_seed = scheduling_seed
      super(ScheduledEmbeddingTrainingHelper, self).__init__(
          inputs=inputs,
          sequence_length=sequence_length,
          time_major=time_major,
          name=name)

  def initialize(self, name=None):
    return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                        [time, outputs, state]):
      # Return -1s where we did not sample, and sample_ids elsewhere
      select_sample_noise = random_ops.random_uniform(
          [self.batch_size], seed=self._scheduling_seed)
      select_sample = (self._sampling_probability > select_sample_noise)
      sample_id_sampler = categorical.Categorical(logits=outputs)
      return array_ops.where(
          select_sample,
          sample_id_sampler.sample(seed=self._seed),
          array_ops.tile([-1], [self.batch_size]))

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                        [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))

      def maybe_sample():
        """Perform scheduled sampling."""
        where_sampling = math_ops.cast(
            array_ops.where(sample_ids > -1), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(sample_ids <= -1), dtypes.int32)
        where_sampling_flat = array_ops.reshape(where_sampling, [-1])
        where_not_sampling_flat = array_ops.reshape(where_not_sampling, [-1])
        sample_ids_sampling = array_ops.gather(sample_ids, where_sampling_flat)
        inputs_not_sampling = array_ops.gather(
            base_next_inputs, where_not_sampling_flat)
        sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))

      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)


class ScheduledOutputTrainingHelper(TrainingHelper):
  """A training helper that adds scheduled sampling directly to outputs.

  Returns False for sample_ids where no sampling took place; True elsewhere.
  """

  def __init__(self, inputs, sequence_length, sampling_probability,
               time_major=False, seed=None, next_input_layer=None,
               auxiliary_inputs=None, name=None):
    """Initializer.

    Args:
      inputs: A (structure) of input tensors.
      sequence_length: An int32 vector tensor.
      sampling_probability: A 0D `float32` tensor: the probability of sampling
        from the outputs instead of reading directly from the inputs.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      seed: The sampling seed.
      next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`.  Optional layer to apply to the RNN output to create
        the next input.
      auxiliary_inputs: An optional (structure of) auxiliary input tensors with
        a shape that matches `inputs` in all but (potentially) the final
        dimension. These tensors will be concatenated to the sampled output or
        the `inputs` when not sampling for use as the next input.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sampling_probability` is not a scalar or vector.
    """
    with ops.name_scope(name, "ScheduledOutputTrainingHelper",
                        [inputs, auxiliary_inputs, sampling_probability]):
      self._sampling_probability = ops.convert_to_tensor(
          sampling_probability, name="sampling_probability")
      if self._sampling_probability.get_shape().ndims not in (0, 1):
        raise ValueError(
            "sampling_probability must be either a scalar or a vector. "
            "saw shape: %s" % (self._sampling_probability.get_shape()))

      if auxiliary_inputs is None:
        maybe_concatenated_inputs = inputs
      else:
        inputs = ops.convert_to_tensor(inputs, name="inputs")
        auxiliary_inputs = ops.convert_to_tensor(
            auxiliary_inputs, name="auxiliary_inputs")
        maybe_concatenated_inputs = nest.map_structure(
            lambda x, y: array_ops.concat((x, y), -1),
            inputs, auxiliary_inputs)
        if not time_major:
          auxiliary_inputs = nest.map_structure(
              _transpose_batch_time, auxiliary_inputs)

      self._auxiliary_input_tas = (
          nest.map_structure(_unstack_ta, auxiliary_inputs)
          if auxiliary_inputs is not None else None)

      self._seed = seed

      if (next_input_layer is not None and not isinstance(next_input_layer,
                                                          layers_base._Layer)):  # pylint: disable=protected-access
        raise TypeError("next_input_layer must be a Layer, received: %s" %
                        type(next_input_layer))
      self._next_input_layer = next_input_layer

      super(ScheduledOutputTrainingHelper, self).__init__(
          inputs=maybe_concatenated_inputs,
          sequence_length=sequence_length,
          time_major=time_major,
          name=name)

  def initialize(self, name=None):
    return super(ScheduledOutputTrainingHelper, self).initialize(name=name)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                        [time, outputs, state]):
      sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
      return math_ops.cast(
          sampler.sample(sample_shape=self.batch_size, seed=self._seed),
          dtypes.bool)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledOutputTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))

      def maybe_sample():
        """Perform scheduled sampling."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Concatenate outputs with auxiliary inputs, if they exist."""
          if self._auxiliary_input_tas is None:
            return outputs_

          next_time = time + 1
          auxiliary_inputs = nest.map_structure(
              lambda ta: ta.read(next_time), self._auxiliary_input_tas)
          if indices is not None:
            auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
          return nest.map_structure(
              lambda x, y: array_ops.concat((x, y), -1),
              outputs_, auxiliary_inputs)

        if self._next_input_layer is None:
          return array_ops.where(
              sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
              base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
            self._next_input_layer(outputs_sampling), where_sampling)

        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))

      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)


class GreedyEmbeddingHelper(Helper):
  """A helper for use during inference.

  Uses the argmax of the output (treated as logits) and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token):
    """Initializer.

    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._batch_size = array_ops.size(start_tokens)
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")
    self._start_inputs = self._embedding_fn(self._start_tokens)

  @property
  def batch_size(self):
    return self._batch_size

  def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.cast(
        math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
