# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""InteractiveModel class for use in T5X Colabs.

The InteractiveModel can be used to run training, inference, and evaluation on
natural text inputs and targets.

"""

import abc
from collections.abc import Mapping, Sequence
import enum
import functools
import inspect
import itertools
import logging
import os
import re
from typing import Any, Callable, Iterator, Optional, Tuple, Union

import clu.data.dataset_iterator
import jax
from jax import random
from jax.experimental import multihost_utils
import numpy as np
import seqio
from t5x import checkpoints
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
from t5x.infer import _extract_tokens_and_aux_values
from t5x.infer import _Inferences
import tensorflow as tf
import tensorflow_datasets as tfds

# A single batch is a sequence of examples, each mapping a feature name to its
# natural-text value.
_Batch = Sequence[Mapping[str, str]]
# Callers may pass either one batch or a sequence of batches.
BatchesType = Union[_Batch, Sequence[_Batch]]


@enum.unique
class InferenceType(enum.Enum):
  """Inference modes accepted by `InteractiveModel.infer_with_preprocessors`."""
  # Run the model's `predict_batch_with_aux` step and decode the predictions.
  PREDICT_WITH_AUX = 1
  # Run the model's `score_batch` step.
  SCORE = 2


class InteractiveModel(abc.ABC):
  """Wrapper around T5X components to enable interactive train/infer/eval."""

  def __init__(self,
               batch_size: int,
               task_feature_lengths: Mapping[str, int],
               output_dir: str,
               partitioner: partitioning.BasePartitioner,
               model: models.BaseTransformerModel,
               dtype: str,
               restore_mode: str,
               checkpoint_path: str,
               input_shapes: Mapping[str, utils.Array],
               input_types: Optional[Mapping[str, utils.DType]] = None,
               init_random_seed: int = 42,
               add_eos: bool = True):
    """Init function.

    Configures the output directory, RNGs, partitioner, checkpoint manager and
    trainer given the provided arguments, and restores (or initializes) the
    train state.

    Args:
      batch_size: number of examples per batch for training, inference, and
        evaluation.
      task_feature_lengths: dictionary mapping feature key to maximum length
        (int) for that feature. If feature is longer than this length after
        preprocessing, the feature will be truncated. May be set to None to
        avoid truncation.
      output_dir: Path to directory where we will write temporary files and
        final results. The directory is created if it does not exist.
      partitioner: the partitioner that defines how we divide and replicate
        machine learning model parameters, activations, and data across the
        accelerator devices (TPU/GPU). See https://github.com/google-research/t5x/blob/main/docs/usage.md/partitioning for
        details.
      model: the model object to use for training, inference, and evaluation.
      dtype: The dtype to restore ('float32' or 'bfloat16'), or None to load as
        saved.
      restore_mode: One of 'specific', 'latest', or 'all'. `specific` loads the
        checkpoint specified by `path`. `latest` loads the most recent
        checkpoint in the directory specified by `path`. `all` sequentially
        loads all of the checkpoints in the directory `path`.
      checkpoint_path: Path(s) to checkpoint to restore from or directory
        (depending on `restore_mode`). If empty, no explicit restore config is
        created and only `output_dir` is searched for a checkpoint.
      input_shapes: a mapping from key to array shape for each feature in the
        global (unsharded) input batch.
      input_types: a mapping from key to array type for each feature in the
        global (unsharded) input batch. If not provided, the type is assumed to
        be `jnp.float32`.
      init_random_seed: the random seed used to initialize all RNGs.
      add_eos: whether or not to add the EOS token to inputs/targets.

    Raises:
      ValueError: the partitioner has an incorrect submesh, or the checkpoint
        restore function returned a sequence of TrainStates, when it should have
        returned a single TrainState.
    """
    self._batch_size = batch_size
    self._task_feature_lengths = task_feature_lengths
    # --------------------------------------------------------------------------
    # Configure the output directory
    # --------------------------------------------------------------------------
    self._output_dir = output_dir
    # Remove double-slashes in directory path to avoid inconsistencies.
    self._output_dir = re.sub(r"(?<!gs:)([\/]{2,})", "/", self._output_dir)
    if not os.path.exists(self._output_dir):
      # `makedirs` also creates missing parent directories, and `exist_ok`
      # tolerates another process creating the directory between the check above
      # and this call.
      os.makedirs(self._output_dir, exist_ok=True)

    # --------------------------------------------------------------------------
    # Initialize RNGs
    # --------------------------------------------------------------------------
    random_seed = multihost_utils.broadcast_one_to_all(
        np.int32(init_random_seed))
    utils.set_hardware_rng_ops()

    rng = random.PRNGKey(random_seed)
    self._init_rng, self._trainer_rng = random.split(rng, 2)

    # --------------------------------------------------------------------------
    # Initialize the partitioner.
    # --------------------------------------------------------------------------
    if partitioner._model_parallel_submesh:
      num_partitions = np.prod(partitioner._model_parallel_submesh)
    else:
      num_partitions = partitioner._num_partitions
    if jax.device_count() % num_partitions != 0:
      # The fragments are concatenated into one message; separating them with
      # commas would instead pass each fragment as its own exception argument.
      raise ValueError(
          "The number of devices available must be a multiple of the number of"
          f" partitions. There are {jax.device_count()} devices available, but"
          f" the number of partitions is set to {num_partitions}. Please"
          " provide a different number of partitions.")
    self._partitioner = partitioner

    # --------------------------------------------------------------------------
    # Create and save a checkpoint manager.
    # --------------------------------------------------------------------------
    logging.info("Initializing model, optimizer, and step functions.")
    self._model = model
    self._feature_converter = self._model.FEATURE_CONVERTER_CLS(pack=False)
    self._input_shapes = input_shapes
    self._input_types = input_types
    # Save the model vocabulary as features.
    output_features = {
        "inputs":
            seqio.Feature(
                vocabulary=self._model.input_vocabulary, add_eos=add_eos),
        "targets":
            seqio.Feature(
                vocabulary=self._model.output_vocabulary, add_eos=add_eos)
    }
    self._features = dict(sorted(output_features.items()))

    # Define restore and save checkpoints.
    if checkpoint_path:
      self._restore_checkpoint_cfg = utils.RestoreCheckpointConfig(
          dtype=dtype, mode=restore_mode, path=checkpoint_path, use_gda=False)
    else:
      self._restore_checkpoint_cfg = None
    self._save_checkpoint_cfg = utils.SaveCheckpointConfig(
        dtype=dtype, keep=5, save_dataset=False, use_gda=False, period=1000)
    self._train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=self._model.optimizer_def,
        init_fn=self._model.get_initial_variables,
        input_shapes=self._input_shapes,
        input_types=self._input_types,
        partitioner=self._partitioner)

    # Initialize checkpoint manager.
    self._checkpoint_manager = utils.LegacyCheckpointManager(
        save_cfg=self._save_checkpoint_cfg,
        restore_cfg=self._restore_checkpoint_cfg,
        train_state_shape=(
            self._train_state_initializer.global_train_state_shape),
        partitioner=self._partitioner,
        ds_iter=None,
        model_dir=self._output_dir,
        use_gda=False)

    # --------------------------------------------------------------------------
    # Restore a model from a checkpoint or from scratch.
    # --------------------------------------------------------------------------
    def get_state(rng):
      return self._train_state_initializer.from_scratch(rng).state_dict()

    restore_cfgs = []
    # 1. From a checkpoint specified by `self._restore_checkpoint_cfg.path`, if
    # set.
    if self._restore_checkpoint_cfg:
      restore_cfgs.append(self._restore_checkpoint_cfg)
    # 2. If no checkpoint provided, look for one in the model directory. Any
    # state transformations of the explicit restore config are reused, flagged
    # as resuming.
    if self._restore_checkpoint_cfg is not None:
      state_transforms_for_restore = [
          functools.partial(fn, is_resuming=True)
          for fn in self._restore_checkpoint_cfg.state_transformation_fns
      ]
    else:
      state_transforms_for_restore = []
    restore_cfgs.append(
        utils.RestoreCheckpointConfig(
            path=self._output_dir,
            mode="latest",
            dtype=self._save_checkpoint_cfg.dtype
            if self._save_checkpoint_cfg else "float32",
            checkpointer_cls=self._save_checkpoint_cfg.checkpointer_cls
            if self._save_checkpoint_cfg else checkpoints.Checkpointer,
            # Restore dataset state if it is being saved.
            restore_dataset=(self._save_checkpoint_cfg and
                             self._save_checkpoint_cfg.save_dataset),
            state_transformation_fns=state_transforms_for_restore))

    # Restore the model using a checkpoint.
    valid_restore_cfg, restore_paths = utils.get_first_valid_restore_config_and_paths(
        restore_cfgs)
    self._train_state = self._checkpoint_manager.restore(
        restore_paths, valid_restore_cfg,
        utils.get_fallback_state(valid_restore_cfg, get_state, self._init_rng))

    # 3. If no checkpoint to restore, init from scratch.
    if self._train_state is None:
      self._train_state = self._train_state_initializer.from_scratch(
          self._init_rng)
    self._train_state_axes = self._train_state_initializer.train_state_axes

    # Log the variable shapes information and write to a file.
    log_file = os.path.join(self._output_dir, "model-info.txt")
    utils.log_model_info(log_file,
                         self._train_state_initializer.global_train_state_shape,
                         self._partitioner)

    # --------------------------------------------------------------------------
    # Trainer
    # --------------------------------------------------------------------------
    if isinstance(self._train_state, Sequence):
      raise ValueError(
          "Expected a single train state, but instead received a Sequence.")
    self._trainer = trainer_lib.Trainer(
        model=self._model,
        train_state=self._train_state,
        partitioner=self._partitioner,
        eval_names=[],
        summary_dir=self._output_dir,
        train_state_axes=self._train_state_axes,
        rng=self._trainer_rng,
        learning_rate_fn=utils.create_learning_rate_scheduler(),
        num_microbatches=None)

  @property
  def train_summary(self):
    """Returns the resolved summary of the most recent training step.

    `self._train_summary` is assigned by `train_step_with_preprocessors`, so
    this property is meant to be read after a training step has run.
    """
    latest_summary = self._train_summary
    return latest_summary.result()

  @property
  def step(self):
    """Returns the step of the current train state as a Python int.

    Raises:
      ValueError: `self._train_state` is a sequence of TrainStates rather than
        a single TrainState.
    """
    current_state = self._train_state
    if not isinstance(current_state, Sequence):
      return int(current_state.step)
    raise ValueError(
        "Expected a single train state, but instead received a Sequence.")

  def train_step(self, examples: Sequence[Union[str, dict[str, str]]]):
    """Runs one training step using the default preprocessing pipeline.

    This is a convenience wrapper around `train_step_with_preprocessors` that
    only tokenizes the examples and appends EOS.

    Args:
      examples: a single batch of examples that should be transformed into a
        tf.data.Dataset. Each example is either a plain string (ex: a single
        input for inference), or a dictionary mapping "input"/"target" to a
        string containing that element.

    Raises:
      ValueError: the user provided less than `batch_size` examples, or
        `self._train_state` was set to a sequence of TrainStates, when it should
        have been a single TrainState.
    """
    self.train_step_with_preprocessors(
        examples=examples,
        # Default pipeline: tokenize, then append EOS.
        preprocessors=[
            seqio.preprocessors.tokenize,
            seqio.preprocessors.append_eos,
        ])

  def train_step_with_preprocessors(
      self, examples: Sequence[Union[str, dict[str, str]]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]]):
    """Runs one training step on `examples`, then saves a checkpoint.

    Args:
      examples: a single batch of examples that should be transformed into a
        tf.data.Dataset. Each example is either a plain string (ex: a single
        input for inference), or a dictionary mapping "input"/"target" to a
        string containing that element.
      preprocessors: list(callable), functions that receive a tf.data.Dataset
        and return a tf.data.Dataset. These will be executed sequentially and
        the final dataset must include features matching `self._features`.

    Raises:
      ValueError: the user provided less than `batch_size` examples, or
        `self._train_state` was set to a sequence of TrainStates, when it should
        have been a single TrainState.
    """
    # A full batch is required because the dataset below is batched with
    # `drop_remainder=True`.
    if len(examples) < self._batch_size:
      raise ValueError(
          "At least one batch of data must be provided. Please decrease the "
          "batch_size or provide more examples.")

    # --------------------------------------------------------------------------
    # Build the batched model-feature dataset and an iterator over it.
    # --------------------------------------------------------------------------
    feature_lengths = self._task_feature_lengths
    task_ds = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=feature_lengths,
        features=self._features)
    model_ds = self._feature_converter(
        task_ds, task_feature_lengths=feature_lengths)
    batched_ds = model_ds.padded_batch(self._batch_size, drop_remainder=True)
    batch_iter = clu.data.dataset_iterator.TfDatasetIterator(
        batched_ds, checkpoint=True)

    # --------------------------------------------------------------------------
    # Take exactly one train step.
    # --------------------------------------------------------------------------
    # Bail out immediately if the trainer has been asked to stop.
    if self._trainer.stop_training:
      logging.info(
          "Stopping training early since `stop_training` is requested.")
      return

    if isinstance(self._train_state, Sequence):
      raise ValueError(
          "Expected a single train state, but instead received a Sequence.")

    def save_trainer_state():
      # Persists the trainer's current state through the checkpoint manager.
      self._checkpoint_manager.save(
          self._trainer.train_state,
          self._save_checkpoint_cfg.state_transformation_fns)

    try:
      start_step = int(utils.get_local_data(self._train_state.step))
      self._train_summary = self._trainer.train(
          batch_iter, 1, start_step=start_step)
    except trainer_lib.PreemptionError:
      # Persist progress before propagating the preemption to the caller.
      logging.info("Saving emergency checkpoint.")
      save_trainer_state()
      logging.info("Saving emergency checkpoint done.")
      raise

    # Regular checkpoint after a successful step.
    logging.info("Saving checkpoint.")
    save_trainer_state()

    # Wait until computations are done before exiting, then adopt the trainer's
    # updated state as the model state.
    utils.sync_global_devices("complete")
    self._train_state = self._trainer.train_state

  def infer_with_preprocessors(
      self, mode: InferenceType, examples: Sequence[Union[str, dict[str, str]]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]]) -> _Inferences:
    """Runs inference on `examples`, one batch at a time.

    Args:
      mode: `InferenceType.SCORE` to run the model's `score_batch` step, or
        `InferenceType.PREDICT_WITH_AUX` to run `predict_batch_with_aux` and
        decode the predictions with the targets vocabulary.
      examples: a single batch of examples that should be transformed into a
        tf.data.Dataset. The examples can either take the form of a string (ex:
        a single input for inference), or a dictionary mapping "input"/"target"
        to a string containing that element.
      preprocessors: list(callable), functions that receive a tf.data.Dataset
        and return a tf.data.Dataset. These will be executed sequentially and
        the final dataset must include features matching `self._features`.

    Returns:
      A tuple `(inferences, aux_values)`. `inferences` is a list with one
      `(task_features, prediction_or_score)` pair per example; `aux_values`
      holds the auxiliary values accumulated over all batches.

    Raises:
      ValueError: `mode` is not a supported `InferenceType`.
    """
    # --------------------------------------------------------------------------
    # Parse Mode
    # --------------------------------------------------------------------------
    if mode == InferenceType.PREDICT_WITH_AUX:
      infer_step = self._model.predict_batch_with_aux
    elif mode == InferenceType.SCORE:
      infer_step = self._model.score_batch
    else:
      raise ValueError("Mode must be `predict_with_aux`, or `score`,"
                       f" but instead was {mode}.")
    infer_fn = functools.partial(
        utils.get_infer_fn(
            infer_step=infer_step,
            batch_size=self._batch_size,
            train_state_axes=self._train_state_initializer.train_state_axes,
            partitioner=self._partitioner),
        train_state=self._train_state)

    # --------------------------------------------------------------------------
    # Construct a dataset and dataset iterator.
    # --------------------------------------------------------------------------
    dataset = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
        features=self._features)
    model_dataset = self._feature_converter(
        dataset, task_feature_lengths=self._task_feature_lengths)
    # Zip task and model features so both stay aligned through batching.
    infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))
    # Create batches and index them. The final batch may be smaller.
    infer_dataset = infer_dataset.padded_batch(
        self._batch_size, drop_remainder=False).enumerate()
    infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(
        infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))

    # --------------------------------------------------------------------------
    # Run inference
    # --------------------------------------------------------------------------
    # Main Loop over "batches".
    all_inferences = []
    all_aux_values = {}
    for chunk, chunk_batch in infer_dataset_iter:
      # Load the dataset for the next chunk. We can't use `infer_dataset_iter`
      # directly since `infer_fn` needs to know the exact size of each chunk,
      # which may be smaller for the final one.
      chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)
      # tf.data transformations return a new dataset rather than modifying the
      # receiver, so the result has to be re-assigned to take effect.
      chunk_dataset = chunk_dataset.cache().prefetch(
          tf.data.experimental.AUTOTUNE)

      # Unzip chunk dataset in to pretokenized and model datasets.
      task_dataset = chunk_dataset.map(
          lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      model_dataset = chunk_dataset.map(
          lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)

      # Get a chunk-specific RNG key (derived from a fixed base key of 0).
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

      inferences = _extract_tokens_and_aux_values(
          infer_fn(model_dataset.enumerate(), rng=chunk_rng))

      predictions, aux_values = inferences
      accumulated_inferences = []
      for idx, inputs in task_dataset.enumerate().as_numpy_iterator():
        prediction = predictions[idx]
        # Decode predictions if applicable.
        if mode == InferenceType.PREDICT_WITH_AUX:
          prediction = self._features["targets"].vocabulary.decode_tf(
              tf.constant(prediction)).numpy()
        accumulated_inferences.append((inputs, prediction))
      all_inferences += accumulated_inferences
      # Accumulate aux values over batches.
      if not all_aux_values:
        all_aux_values = aux_values
      else:
        for key, values in aux_values.items():
          all_aux_values[key] += values

    return all_inferences, all_aux_values

  def predict_with_aux(
      self, examples: Sequence[Union[str, dict[str, str]]]) -> _Inferences:
    """Predicts on `examples` with the default tokenize + EOS preprocessing.

    Delegates to `infer_with_preprocessors` in `PREDICT_WITH_AUX` mode and
    returns its result unchanged.
    """
    return self.infer_with_preprocessors(
        mode=InferenceType.PREDICT_WITH_AUX,
        examples=examples,
        # Default pipeline: tokenize, then append EOS.
        preprocessors=[
            seqio.preprocessors.tokenize,
            seqio.preprocessors.append_eos,
        ])

  def score(
      self, examples: Sequence[Union[str, dict[str, str]]]) -> Sequence[Any]:
    """Scores `examples` with the default tokenize + EOS preprocessing.

    Delegates to `infer_with_preprocessors` in `SCORE` mode and returns only
    the first element of its result; the auxiliary values are dropped.
    """
    default_preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    scored_examples, _unused_aux_values = self.infer_with_preprocessors(
        mode=InferenceType.SCORE,
        examples=examples,
        preprocessors=default_preprocessors)
    return scored_examples

  def _compute_metrics(
      self, targets: Sequence[Any], predictions: Sequence[Any],
      aux_values: Sequence[Any], scores: Sequence[Any],
      predict_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
      predict_with_aux_metric_fns: Sequence[
          seqio.dataset_providers.MetricFnCallable],
      score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable]):
    """Computes the metrics specified in the metric_fns lists.

    Args:
      targets: the evaluation targets, passed as the first argument to every
        metric function.
      predictions: model predictions, passed to the `predict_metric_fns` and
        `predict_with_aux_metric_fns`.
      aux_values: auxiliary values, passed to `predict_with_aux_metric_fns`.
      scores: model scores, passed to `score_metric_fns`. May be a tuple, in
        which case its first element is length-checked against `targets`.
      predict_metric_fns: metric functions called as `fn(targets, predictions)`.
      predict_with_aux_metric_fns: metric functions called as `fn(targets,
        predictions, aux_values)`.
      score_metric_fns: metric functions called as `fn(targets, scores)`.

    Returns:
      A dict merging the results of all metric functions. Processes other than
      process 0 return an empty dict.

    Raises:
      ValueError: the number of scores does not match the number of targets, or
        two metric functions produced the same metric key.
    """
    # Only compute metrics once
    if jax.process_index() != 0:
      return {}

    def compute_metrics_fn():
      task_metrics = []
      if predict_metric_fns:
        task_metrics.extend([
            metric_fn(targets, predictions) for metric_fn in predict_metric_fns
        ])
      if predict_with_aux_metric_fns:
        task_metrics.extend([
            metric_fn(targets, predictions, aux_values)
            for metric_fn in predict_with_aux_metric_fns
        ])
      if score_metric_fns:
        # Report the same length that is compared: for tuple-valued scores
        # that is the length of the first element, not of the tuple itself.
        if isinstance(scores, tuple):
          num_scores = len(scores[0])
        else:
          num_scores = len(scores)
        if len(targets) != num_scores:
          raise ValueError(f"len(targets)({len(targets)}) != "
                           f"len(output_scores)({num_scores})")
        task_metrics.extend(
            [metric_fn(targets, scores) for metric_fn in score_metric_fns])

      # Merge the per-function result dicts, rejecting colliding keys.
      all_metrics = {}
      for k, v in itertools.chain(*[m.items() for m in task_metrics]):
        if k in all_metrics:
          raise ValueError(f"Duplicate metric key '{k}' in Task.")
        all_metrics[k] = v
      return all_metrics

    if not tf.executing_eagerly():

      def wrap_graph(fn):
        # Capture the current default graph so `fn` runs inside it.
        graph = tf.compat.v1.get_default_graph()

        def wrapped_fn():
          with graph.as_default():
            return fn()

        return wrapped_fn

      compute_metrics_fn = wrap_graph(compute_metrics_fn)

    all_metrics = compute_metrics_fn()
    # Wait until computations are done before continuing.
    # NOTE(review): only process 0 reaches this barrier because every other
    # process returned above -- confirm this cannot hang in multi-host runs.
    utils.sync_global_devices("Completed.")
    return all_metrics

  def evaluate(
      self,
      examples: Sequence[Union[str, dict[str, str]]],
      metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
  ) -> Mapping[Any, Any]:
    """Evaluates `examples` with the default tokenize + EOS preprocessing.

    Delegates to `evaluate_with_preprocessors` without a postprocessor.

    Args:
      examples: a single batch of examples that should be transformed into a
        tf.data.Dataset. Each example is either a plain string (ex: a single
        input for inference), or a dictionary mapping "input"/"target" to a
        string containing that element.
      metric_fns: list(callable), metric functions whose signature matches one
        of three possible forms:
        - (targets, scores): `scores` refers to the score the model assigned
          the target sequence, given the input.
        - (targets, predictions)
        - (targets, predictions, aux_values): `aux_values` refers to a
          dictionary of auxiliary values that the model assigned to each
          sequence.

    Returns:
      Mapping of metrics names to metrics values.
    """
    default_preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    metrics = self.evaluate_with_preprocessors(
        examples=examples,
        preprocessors=default_preprocessors,
        metric_fns=metric_fns,
        postprocessor=None)
    return metrics

  def evaluate_with_preprocessors(
      self,
      examples: Sequence[dict[str, str]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
      metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
      postprocessor: Optional[Callable[..., Any]] = None,
  ) -> Mapping[Any, Any]:
    """Evaluation function.

    Args:
      examples: a single batch of examples that should be transformed into a
        tf.data.Dataset. The examples must take the form of a dictionary mapping
        "input"/"target" to a string containing that element.
      preprocessors: list(callable), functions that receive a tf.data.Dataset
        and return a tf.data.Dataset. These will be executed sequentially and
        the final dataset must include features matching `self._features`.
      metric_fns: list(callable), metric functions whose required parameters
        match one of three possible forms:
        - (targets, scores): `scores` refers to the score the model assigned
          the target sequence, given the input.
        - (targets, predictions)
        - (targets, predictions, aux_values): `aux_values` refers to a
          dictionary of auxiliary values that the model assigned to each
          sequence.
      postprocessor: callable, an optional function applied to each target
        (called with `example` and `is_target=True` keyword arguments) before
        the targets are handed to the metric functions in `metric_fns`.

    Returns:
      Mapping of metrics names to metrics values.

    Raises:
      ValueError: a metric function's required parameters do not match any of
        the supported forms.
    """
    # --------------------------------------------------------------------------
    # Parse Metrics functions
    # --------------------------------------------------------------------------
    predict_metric_fns = []
    predict_with_aux_metric_fns = []
    score_metric_fns = []
    for metric_fn in metric_fns:
      # Only parameters without a default decide which form a function has.
      # `Parameter.empty` is a sentinel, so it is compared by identity.
      pos_args = tuple(
          key for key, param in inspect.signature(metric_fn).parameters.items()
          if param.default is inspect.Parameter.empty)
      if pos_args == ("targets", "scores"):
        score_metric_fns.append(metric_fn)
      elif pos_args == ("targets", "predictions"):
        predict_metric_fns.append(metric_fn)
      elif pos_args == ("targets", "predictions", "aux_values"):
        predict_with_aux_metric_fns.append(metric_fn)
      else:
        raise ValueError(
            "Metric functions must have positional arguments matching either "
            "('targets', 'scores'), ('targets', 'predictions') or "
            "('targets', 'predictions', 'aux_values'). "
            f"Got: {pos_args}")

    # ------------------------------------------------------------------------
    # Get targets, predictions, and scores
    # ------------------------------------------------------------------------
    dataset = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
        features=self._features)

    # Get targets.
    def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
      """Returns the model output after applying the postprocess function."""
      if postprocessor:
        return postprocessor(decoded_model_output, **postprocess_kwargs)
      return decoded_model_output

    targets = []
    for ex in tfds.as_numpy(dataset):
      targets.append(
          postprocess_fn(
              decoded_model_output=ex["targets_pretokenized"],
              example=ex,
              is_target=True))

    # Get predictions. `aux_values` needs a default because it is passed to
    # `_compute_metrics` even when only score metrics were requested and no
    # prediction pass runs.
    predictions = []
    aux_values = {}
    if predict_with_aux_metric_fns or predict_metric_fns:
      predictions, aux_values = self.infer_with_preprocessors(
          mode=InferenceType.PREDICT_WITH_AUX,
          examples=examples,
          preprocessors=preprocessors)
      # NOTE(review): unlike the targets, predictions are not passed through
      # `postprocess_fn` -- confirm this asymmetry is intended.
      predictions = [
          prediction.decode("utf-8") for _, prediction in predictions
      ]
    # Get scores.
    scores = []
    if score_metric_fns:
      scores, _ = self.infer_with_preprocessors(
          mode=InferenceType.SCORE,
          examples=examples,
          preprocessors=preprocessors)
      scores = [score for _, score in scores]

    return self._compute_metrics(
        targets,
        predictions,
        aux_values,
        scores,  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
        predict_metric_fns,
        predict_with_aux_metric_fns,
        score_metric_fns)

  def train_loop(
      self,
      num_steps: int,
      eval_period: Optional[int] = 1,
      train_batches: Optional[BatchesType] = None,
      predict_batches: Optional[BatchesType] = None,
      score_batches: Optional[BatchesType] = None,
      eval_batches: Optional[BatchesType] = None,
      metrics_fns: Optional[Sequence[
          seqio.dataset_providers.MetricFnCallable]] = None,
  ):
    """Runs training, inference, and evaluation for `num_steps`.

    It should be noted that there are many different possible variants of the
    `train_loop` function that a user might want to use. The primary goal of the
    `train_loop` function is not to cover all the potential training loop
    variants that a user may want; rather, the goal is to demonstrate how the
    user could stack the `InteractiveModel` train, predict, score, and evaluate
    methods.

    Args:
      num_steps: the number of steps to run for training, inference, and
        evaluation.
      eval_period: specifies how many steps to take between
        inference/evaluation. If None or 0, inference/evaluation is skipped.
      train_batches: an optional list of batches that we should run training on.
        If no batches are provided, then training will be skipped. If a single
        batch is provided, we will repeat training on this batch for
        `num_steps`.
      predict_batches: an optional list of batches that we should get
        predictions for. If no batches are provided, then predicting will be
        skipped. If a single batch is provided, we will repeatedly get
        predictions on this batch for `num_steps`.
      score_batches: an optional list of batches that we should score. If no
        batches are provided, then scoring will be skipped. If a single batch is
        provided, we will repeatedly score this batch for `num_steps`.
      eval_batches: an optional list of batches that we should run eval on. If
        no batches are provided, then evaluation will be skipped. If a single
        batch is provided, we will repeatedly evaluate this batch for
        `num_steps`.
      metrics_fns: list(callable), an optional list of metric functions with a
        signature that matches one of three possible forms:
        - (targets, scores): `scores` refers to the score the model assigned
          the target sequence, given the input.
        - (targets, predictions)
        - (targets, predictions, aux_values): `aux_values` refers to a
          dictionary of auxiliary values that the model assigned to each
          sequence.
        If None, evaluation runs with no metric functions.

    Returns:
      Predictions, scores, and metrics for the final step of the training loop.
    """

    def as_batch_list(batches):
      # A single batch is a sequence of examples (strings or mappings); wrap it
      # so the loops below always iterate over whole batches, never over the
      # individual examples of one batch.
      if batches and isinstance(batches[0], (str, Mapping)):
        return [batches]
      return batches

    # Ensure all batches are `num_steps` in length
    train_batches = _get_equal_length_batches(train_batches, num_steps)
    predict_batches = as_batch_list(predict_batches)
    score_batches = as_batch_list(score_batches)
    eval_batches = as_batch_list(eval_batches)
    eval_metric_fns = metrics_fns if metrics_fns is not None else []

    predictions = None
    scores = None
    metrics = None
    for step_num, train_batch in enumerate(train_batches):
      if train_batch:
        self.train_step(train_batch)
      # Run inference/evaluation every `eval_period` steps. A missing or zero
      # period disables it instead of failing on the modulo.
      if eval_period and step_num % eval_period == 0:
        # Run on all batches for inference/evaluation.
        if predict_batches:
          for predict_batch in predict_batches:
            predictions, _ = self.predict_with_aux(predict_batch)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
        if score_batches:
          for score_batch in score_batches:
            scores = self.score(score_batch)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
        if eval_batches:
          for eval_batch in eval_batches:
            metrics = self.evaluate(eval_batch, eval_metric_fns)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
    return predictions, scores, metrics


def get_dataset_from_natural_text_examples(
    examples: Sequence[Union[str, dict[str, str]]],
    preprocessors: Sequence[Callable[..., tf.data.Dataset]],
    task_feature_lengths: Mapping[str, int],
    features: Mapping[str, Any]) -> tf.data.Dataset:
  """Returns a tf.data.Dataset built from a batch of natural text examples.

  The examples are merged into "inputs"/"targets"/"answers" columns, run through
  each preprocessor in order, validated against `features`, and finally trimmed
  to `task_feature_lengths`.

  Args:
    examples: a single batch of examples that should be transformed into a
      tf.data.Dataset. Each example is either a string (ex: a single input for
      inference, which is paired with an empty target), or a mapping that
      provides both an "input" and a "target" string.
    preprocessors: a list of functions that receive a tf.data.Dataset and
      return a tf.data.Dataset. These will be executed sequentially and the
      final dataset must include features matching `features`. A preprocessor
      only receives `sequence_length` / `output_features` if it declares a
      parameter with that exact name (`**kwargs` is not enough).
    task_feature_lengths: dictionary mapping feature key to maximum length (int)
      for that feature. If feature is longer than this length after
      preprocessing, the feature will be truncated. May be set to None to avoid
      truncation.
    features: dictionary defining what features should be present in all
      examples. Each value must expose `required`, `dtype` and `rank`
      attributes (presumably `seqio.Feature` instances -- confirm with callers).

  Returns:
    The preprocessed, validated and trimmed tf.data.Dataset, with prefetching
    enabled.

  Raises:
    KeyError: a mapping example is missing the "input" or "target" key.
    ValueError: the preprocessed dataset is missing a required feature, or has
      the incorrect type/rank for a feature.
  """
  # ------------------------------------------------------------------------
  # Construct a `tf.data.Dataset` from the provided examples
  # ------------------------------------------------------------------------
  # `from_tensor_slices` expects one column per feature, so transpose the list
  # of examples into a dictionary of lists.
  merged_examples = {"inputs": [], "targets": [], "answers": []}
  for example in examples:
    # Accept any Mapping (not only `dict`), matching `BatchesType` and the
    # check performed in `_get_equal_length_batches`.
    if isinstance(example, Mapping):
      example_dict = example
    else:
      # If the provided example is just a string, add an empty target string.
      example_dict = {"input": example, "target": ""}
    merged_examples["inputs"].append(example_dict["input"])
    merged_examples["targets"].append(example_dict["target"])
    # Answers is a list of possible targets (in this case a single target).
    merged_examples["answers"].append([example_dict["target"]])
  dataset = tf.data.Dataset.from_tensor_slices(merged_examples)

  # Define `ShardInfo` that doesn't shard the data pipeline.
  shard_info = seqio.ShardInfo(0, 1)
  dataset = dataset.shard(shard_info.num_shards, shard_info.index)

  # ------------------------------------------------------------------------
  # Preprocess data
  # ------------------------------------------------------------------------
  for prep_fn in preprocessors:
    # prep_fn must not rely on variable length keyword args such as **kwargs:
    # optional arguments are only forwarded when they are named explicitly.
    fn_args = set(inspect.signature(prep_fn).parameters.keys())
    kwargs = {}
    if "sequence_length" in fn_args:
      kwargs["sequence_length"] = task_feature_lengths
    if "output_features" in fn_args:
      kwargs["output_features"] = features
    dataset = prep_fn(dataset, **kwargs)

  def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Validates preprocessed dataset, raising Exceptions if needed.

    Args:
      dataset: a tf.data.Dataset to validate.

    Returns:
      The same tf.data.Dataset, unchanged.

    Raises:
      ValueError: dataset has missing feature or the incorrect type/rank for a
        feature.
    """
    actual_specs = dataset.element_spec
    for feat, feat_spec in features.items():
      if feat not in actual_specs:
        if feat_spec.required:
          raise ValueError(
              "Task dataset is missing expected output feature after "
              f"preprocessing: {feat}")
        # It's ok that this optional feature does not exist.
        continue
      actual_spec = actual_specs[feat]
      if feat_spec.dtype != actual_spec.dtype:
        raise ValueError(
            f"Task dataset has incorrect type for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.dtype.name}, expected "
            f"{feat_spec.dtype.name}")
      if feat_spec.rank != actual_spec.shape.rank:
        raise ValueError(
            f"Task dataset has incorrect rank for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.shape.rank}, expected "
            f"{feat_spec.rank}")

    return dataset

  dataset = _validate_preprocessing(dataset)
  dataset = seqio.utils.trim_dataset(dataset, task_feature_lengths, features)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)


def _get_equal_length_batches(batches: BatchesType,
                              length: int) -> Sequence[Any]:
  """Produces a list of batches that is `length` batches long.

  Given no batches (empty or `None`), return `length` copies of `None`.

  Given a single batch (a sequence of example mappings), repeat the batch
  `length` times.

  Given a list of batches, either cycle through the batches until there are
  `length` total batches, or take the first `length` batches.

  Args:
    batches: either a single batch of examples, or a list of batches.
    length: the total number of batches that should be present in the final
      list.

  Returns:
    A list of batches.
  """
  if not batches:
    return [None] * length
  # A single batch is a sequence of examples (mappings), whereas a list of
  # batches is a sequence of such sequences.
  if isinstance(batches[0], Mapping):
    return [batches] * length

  num_batches = len(batches)
  if num_batches > length:
    # If too many batches are provided, only use the first `length` batches.
    logging.warning(
        "We will only use the first %s batches provided for training.", length)
  elif num_batches < length:
    # Round the repeat count up (ceiling division) so that the repeated list is
    # at least `length` long even when `length` is not a multiple of the number
    # of batches; the slice below drops any surplus.
    repeats = -(-length // num_batches)
    batches = batches * repeats
  return batches[:length]


def get_batches_from_seqio(
    task_or_mixture_name: str,
    split: str,
    batch_size: int,
    num_batches: int,
    get_pretokenized_examples: bool = True,
    sequence_length: Optional[Mapping[str, int]] = None,
    **get_dataset_kwargs) -> Sequence[Sequence[Mapping[str, str]]]:
  """Returns batches of examples from a provided SeqIO task.

  Args:
    task_or_mixture_name: the SeqIO task/mixture to read data from.
    split: the split of the SeqIO task/mixture to read data from.
    batch_size: how many examples should be in each batch.
    num_batches: the total number of batches to return.
    get_pretokenized_examples: a bool, where True indicates that we should
      return the natural text (pre-tokenization) inputs and targets, read from
      the "inputs_pretokenized"/"targets_pretokenized" features. When False,
      the "inputs"/"targets" features are read and converted to Python lists.
      Default to True in order to make the examples easy to debug/inspect.
    sequence_length: dictionary mapping feature key to maximum length (int) for
      that feature. Used by SeqIO to get the dataset.
    **get_dataset_kwargs: any additional arguments that should be passed to the
      SeqIO `get_dataset()` call.

  Returns:
    A sequence of batches, where each batch is a sequence of examples. Each
      example is a dictionary mapping 'input' and 'target' to the corresponding
      values for a single example.

  Raises:
    ValueError: the Task/Mixture contains fewer than `batch_size * num_batches`
      examples.
  """
  task_or_mixture = seqio.get_mixture_or_task(task_or_mixture_name)
  total_examples_requested = batch_size * num_batches
  dataset = task_or_mixture.get_dataset(
      sequence_length=sequence_length, split=split, **get_dataset_kwargs)

  if get_pretokenized_examples:
    input_key, target_key = "inputs_pretokenized", "targets_pretokenized"
  else:
    input_key, target_key = "inputs", "targets"

  all_batches = []
  current_batch = []
  total_examples_seen = 0
  # It should be noted that we could replace the following loop with tf.Dataset
  # operations (like
  # `list(dataset.batch(batch_size).take(num_batches).as_numpy_iterator())`),
  # but this would require us to pad batches first or represent the token IDs as
  # ragged tensors. These approaches are currently overkill for the
  # InteractiveModel, but may be investigated in the future.
  #
  # `take` caps the stream at a multiple of `batch_size`, so every collected
  # batch is full and no further bounds check is needed inside the loop. The
  # count is clamped at zero because `take` treats -1 as "take everything".
  dataset = dataset.take(max(total_examples_requested, 0))
  for element in dataset.as_numpy_iterator():
    total_examples_seen += 1

    example_input = element[input_key]
    example_target = element[target_key]
    if not get_pretokenized_examples:
      # Tokenized features are arrays; convert them to plain Python lists.
      example_input = example_input.tolist()
      example_target = example_target.tolist()
    current_batch.append({"input": example_input, "target": example_target})

    # If we've collected `batch_size` examples, save the current batch and start
    # a new batch.
    if len(current_batch) == batch_size:
      all_batches.append(current_batch)
      current_batch = []

  if total_examples_seen < total_examples_requested:
    raise ValueError("Not enough examples in Task/Mixture. User requested "
                     f"{num_batches} batches of size {batch_size} for a total "
                     f"of {total_examples_requested} examples. Only "
                     f"{total_examples_seen} available in "
                     "Task/Mixture.")

  return all_batches


def get_seqio_task_from_examples(
    task_name: str,
    interactive_model: InteractiveModel,
    examples: Sequence[Union[str, dict[str, str]]],
    preprocessors: Sequence[Callable[..., tf.data.Dataset]],
    metric_fns: Optional[Sequence[
        seqio.dataset_providers.MetricFnCallable]] = None,
    add_to_registry: bool = True) -> seqio.Task:
  """Builds (and optionally registers) a SeqIO task from the provided inputs.

  This function will be used to graduate people to the T5X/SeqIO-based
  train/infer/eval scripts.

  The task's data source serves the same `examples` for both the "train" and
  "validation" splits, and ignores file shuffling.

  Args:
    task_name: the name of the SeqIO task to be created and registered.
    interactive_model: an instance of the InteractiveModel. Its task feature
      lengths and output features are used to configure the task.
    examples: a single batch of examples that should be transformed into a
      tf.data.Dataset. The examples can either take the form of a string (ex: a
      single input for inference), or a dictionary mapping "input"/"target" to a
      string containing that element.
    preprocessors: an optional list of functions that receive a tf.data.Dataset
      and return a tf.data.Dataset. These will be executed sequentially by the
      task and the final dataset must include features matching the output
      features of `interactive_model`.
    metric_fns: list(callable), an optional list of metric functions with a
      signature that matches one of three possible forms: - (targets, scores) -
      Note that `scores` refers to the score the model assigned the target
      sequence, given the input. - (targets, predictions) - (targets,
      predictions, aux_values) - Note that `aux_values` refers to a dictionary
      of auxiliary values that the model assigned to each sequence.
    add_to_registry: if True, will register the new task in the SeqIO
      `TaskRegistry` and return the registered task. If False, a standalone
      task is built and returned without touching the registry.

  Returns:
    A SeqIO task.
  """

  def dataset_fn(split, shuffle_files):
    # Every split serves the same in-memory examples.
    del split, shuffle_files
    # Preprocessing and feature validation are left to the SeqIO task itself,
    # so the raw dataset is built with no preprocessors and no features.
    return get_dataset_from_natural_text_examples(
        examples,
        preprocessors=[],
        task_feature_lengths=interactive_model._task_feature_lengths,  # pylint: disable=protected-access
        features={})

  data_source = seqio.FunctionDataSource(
      dataset_fn=dataset_fn, splits=["train", "validation"])
  output_features = interactive_model._features  # pylint: disable=protected-access

  if not add_to_registry:
    # Looking the name up in the registry would fail for a task that was never
    # registered (or return an unrelated task), so build the task directly.
    return seqio.Task(
        task_name,
        source=data_source,
        output_features=output_features,
        preprocessors=preprocessors,
        metric_fns=metric_fns)

  seqio.TaskRegistry.add(
      task_name,
      data_source,
      preprocessors=preprocessors,
      output_features=output_features,
      metric_fns=metric_fns)
  return seqio.get_mixture_or_task(task_name)
