from typing import Callable
from typing import List
from typing import Optional
from typing import Text
from typing import Tuple

import gin
import tensorflow as tf
from absl import logging
from tf_agents.agents.ppo import ppo_agent
from tf_agents.train import interval_trigger
from tf_agents.train import learner
from tf_agents.typing import types

# Type of one dataset element: a nested tensor representing a TF-Agents
# Trajectory paired with a Reverb SampleInfo.
_SequenceParamsType = Tuple[types.NestedTensor, types.ReverbSampleInfo]
# Type of a per-sequence preprocessing function: it receives one such element
# and returns an element of the same type.
_SequenceFnType = Callable[[_SequenceParamsType], _SequenceParamsType]


@gin.configurable(
    allowlist=[
        'checkpoint_interval',
        'summary_interval',
        'allow_variable_length_episodes',
    ]
)
class PPOLearner(object):
  """Manages all the learning details needed.

  These include:
    * Using distribution strategies correctly
    * Summaries
    * Checkpoints
    * Minimizing entering/exiting TF context:
        Especially in the case of TPUs scheduling a single TPU program to
        perform multiple train steps is critical for performance.
    * Generalizes the train call to be done correctly across CPU, GPU, or TPU
      executions managed by DistributionStrategies. This uses `strategy.run` and
      then makes sure to do a reduce operation over the `LossInfo` returned by
      the agent.
  """

  def __init__(
      self,
      root_dir: Text,
      train_step: tf.Variable,
      model_id: tf.Variable,
      agent: ppo_agent.PPOAgent,
      experience_datasets_fn: Callable[[], tf.data.Dataset],
      sequence_length: int,
      num_episodes_per_iteration: int,
      minibatch_size: int,
      shuffle_buffer_size: int,
      num_epochs: int,
      triggers: Optional[List[interval_trigger.IntervalTrigger]] = None,
      strategy: Optional[tf.distribute.Strategy] = None,
      per_sequence_fn: Optional[_SequenceFnType] = None,
      checkpoint_interval: int = 100000,
      summary_interval: int = 200,
      allow_variable_length_episodes: bool = False,
  ) -> None:
    """Initializes a PPOLearner instance.

    Besides storing its arguments, the constructor builds the wrapped
    `learner.Learner`, creates the training dataset and its iterator, and
    computes the number of train steps performed by every `run` call.

    Args:
      root_dir: Main directory path where checkpoints, saved_models, and
        summaries will be written to.
      train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
        number of train steps. This is used for artifacts created like
        summaries, or outputs in the root_dir.
      model_id: a scalar tf.int64 `tf.Variable` which will keep track of the
        number of learner iterations / policy updates. It is incremented by
        one at the end of every `run` call and compared against the priority
        of each sample to filter out data generated by other models.
      agent: `ppo_agent.PPOAgent` instance to train with. The agent is expected
        to be configured with `update_normalizers_in_train=False` and
        `compute_value_and_advantage_in_train=False`, with preprocessing done
        as part of the data pipeline (see `per_sequence_fn`). NOTE(review):
        this constructor does not itself validate either setting -- confirm
        whether they are enforced elsewhere.
      experience_datasets_fn: a function that returns the `tf.data.Dataset`s
        used to sample experience for training. Its return value is iterated
        over; each dataset is processed separately and the results are mixed
        with `tf.data.Dataset.sample_from_datasets`. Each element of a dataset
        is a (Trajectory, SampleInfo) pair. NOTE(review): the annotation
        declares a single `tf.data.Dataset` return value, which does not match
        this usage -- confirm and update the annotation.
      sequence_length: Fixed sequence length for elements in the dataset. Used
        for calculating how many iterations of minibatches to use for training
        and, unless `allow_variable_length_episodes` is set, for filtering out
        episodes of a different length.
      num_episodes_per_iteration: The number of episodes each `run` call is
        assumed to train on. It is only used, together with `sequence_length`,
        `num_epochs` and `minibatch_size`, to compute the number of train steps
        per iteration.
      minibatch_size: The minibatch size. The dataset used for training is
        shaped `[minibatch_size, 1, ...]`. Must be a non-zero integer: it is
        used both as a batch size and as the divisor when computing the number
        of train steps per iteration.
      shuffle_buffer_size: The buffer size for shuffling the single-step items
        before grouping them into mini batches. Commonly set to a number 1-3x
        the episode length of your environment.
      num_epochs: The number of iterations to go through the same sequences.
        Passed to `Dataset.repeat` and used to scale the number of train steps
        per iteration.
      triggers: List of callables of the form `trigger(train_step)`. After every
        `run` call every trigger is called with the current `train_step` value
        as an np scalar.
      strategy: (Optional) `tf.distribute.Strategy` to use during training.
        Defaults to `tf.distribute.get_strategy()`.
      per_sequence_fn: (Optional) sequence-wise preprocessing, e.g. pass in
        `agent.preprocess` for advantage calculation. It is mapped over every
        dataset after invalid episodes are filtered out and before the
        sequences are unbatched and regrouped into mini batches.
      checkpoint_interval: Number of train steps in between checkpoints. Note
        these are placed into triggers and so a check to generate a checkpoint
        only occurs after every `run` call. Set to -1 to disable (this is not
        recommended, because it means that if the pipeline gets preempted, all
        previous progress is lost). This only takes care of the checkpointing
        the training process.  Policies must be explicitly exported through
        triggers.
      summary_interval: Number of train steps in between summaries. Note these
        are placed into triggers and so a check to generate a summary only
        occurs after every `run` call.
      allow_variable_length_episodes: Whether to support variable length
        episodes for training. When `False`, episodes whose length differs from
        `sequence_length` are filtered out of the training data.
    """

    self._strategy = strategy or tf.distribute.get_strategy()
    self._agent = agent
    self._minibatch_size = minibatch_size
    self._shuffle_buffer_size = shuffle_buffer_size
    self._num_epochs = num_epochs
    self._experience_datasets_fn = experience_datasets_fn
    self._num_episodes_per_iteration = num_episodes_per_iteration
    # Tracks the number of times learner.run() has been called.
    # This is used for filtering out data generated by older models to ensure
    # the on policyness of the algorithm.
    self._model_id = model_id
    self._sequence_length = sequence_length
    self._per_sequence_fn = per_sequence_fn

    # The generic learner owns checkpointing, summaries and the distributed
    # train loop; this class only prepares the data it is fed with.
    self._generic_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        after_train_strategy_step_fn=None,
        triggers=triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        use_kwargs_in_agent_train=False,
        strategy=self._strategy,
    )

    self.num_replicas = self._strategy.num_replicas_in_sync
    self._allow_variable_length_episodes = allow_variable_length_episodes
    # Total number of single-step samples per iteration, assuming every episode
    # has exactly `sequence_length` steps.
    self._num_samples = self._num_episodes_per_iteration * self._sequence_length
    # Every attribute read by the two calls below must be assigned before them.
    self._create_datasets()
    self._steps_per_iter = self._get_train_steps_per_iteration()
    logging.info('train steps per iteration: %d', self._steps_per_iter)

  @property
  def train_summary_writer(self):
    return self._generic_learner.train_summary_writer

  def _create_datasets(self):
    """Builds the training dataset and the iterator consumed by `run`.

    Sets `self._train_dataset` and `self._train_iterator`.
    """

    def _is_valid_episode(sample):
      """Predicate keeping only samples usable for the current model."""
      # The smallest priority in the sample is interpreted as the id of the
      # model that generated the data.
      sample_model_id = tf.cast(
          tf.reduce_min(sample.info.priority), dtype=tf.int64
      )
      is_on_policy = tf.math.equal(self._model_id, sample_model_id)
      if self._allow_variable_length_episodes:
        # Only off policy samples are dropped.
        return is_on_policy

      # Additionally drop infeasible placements, i.e. episodes that are
      # shorter than the expected sequence length.
      has_expected_length = tf.math.equal(
          tf.size(sample.data.discount), self._sequence_length
      )
      return tf.math.logical_and(has_expected_length, is_on_policy)

    def _to_minibatches(episodes):
      """Turns one dataset of episodes into a dataset of mini batches."""
      episodes = episodes.filter(_is_valid_episode)
      if self._per_sequence_fn:
        episodes = episodes.map(
            self._per_sequence_fn,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )

      # The [B, T, ...] shaped dataset is unbatched into individual elements.
      # Unbatching happens across the time dimension, so a mini batch may
      # contain subsets of more than one sequence. The PPO agent can handle
      # mini batches across episode boundaries.
      return (
          episodes.unbatch()
          .batch(1, drop_remainder=True)
          .shuffle(self._shuffle_buffer_size)
          .repeat(self._num_epochs)
          .batch(self._minibatch_size, drop_remainder=True)
      )

    def _build_dataset(_) -> tf.data.Dataset:
      """Mixes the mini batches of every experience dataset into one."""
      # Each item of an experience dataset is a (Trajectory, SampleInfo) tuple,
      # and the Trajectory represents one single episode of a fixed sequence
      # length. The Trajectory dimensions are [1, T, ...].
      experience_datasets = self._experience_datasets_fn()
      minibatch_datasets = [
          _to_minibatches(episodes) for episodes in experience_datasets
      ]
      mixed_dataset = tf.data.Dataset.sample_from_datasets(
          minibatch_datasets, stop_on_empty_dataset=False
      )

      pipeline_options = tf.data.Options()
      pipeline_options.deterministic = False
      pipeline_options.experimental_optimization.parallel_batch = True
      return mixed_dataset.with_options(pipeline_options)

    strategy = self._strategy
    with strategy.scope():
      if strategy.num_replicas_in_sync > 1:
        self._train_dataset = strategy.distribute_datasets_from_function(
            _build_dataset
        )
      else:
        self._train_dataset = _build_dataset(0)
      self._train_iterator = iter(self._train_dataset)

  def _get_train_steps_per_iteration(self):
    """Number of train steps each time learner.run() is called."""

    # We exhaust all num_episodes_per_iteration taken from Reverb in this setup.
    # Here we assume that there's only 1 episode per batch, and each episode is
    # of the fixed sequence length.
    num_mini_batches = int(
        self._num_samples * self._num_epochs / self._minibatch_size
    )
    train_steps = int(num_mini_batches / self.num_replicas)
    return train_steps

  def wait_for_data(self):
    """Blocking call on dataset."""
    sample_info = next(self._train_iterator)[1]
    logging.info('Sample priority: %s', sample_info.priority)

  def run(self):
    """Train `num_episodes_per_iteration` repeating for `num_epochs` of iterations.

    Returns:
      The total loss computed before running the final step.
    """
    loss_info = self._generic_learner.run(
        self._steps_per_iter, self._train_iterator
    )
    self._model_id.assign_add(1)
    return loss_info

  @property
  def train_step_numpy(self):
    """The current train_step.

    Returns:
      The current `train_step`. Note this will return a scalar numpy array which
      holds the `train_step` value when this was called.
    """
    return self._generic_learner.train_step_numpy
