# coding=utf-8
# Copyright 2022 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Forming the first part of a tf.data pipeline, reading from a source on disk.

The data output by the Reader consists of episodes or batches (for EpisodeReader
and BatchReader respectively) from one source (one split of a dataset). They
contain strings representing images that have not been decoded yet, and can
contain placeholder examples and examples to discard.
See data/pipeline.py for the next stage of the pipeline.
"""
# TODO(lamblinp): Update variable names to be more consistent
# - target, class_idx, label

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools
import os

from meta_dataset import data
import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf

# PLACEHOLDER_CLASS_ID is the class ID (target) attached to placeholder
# examples. Placeholder examples carry an empty example string and are used for
# padding only.
PLACEHOLDER_CLASS_ID = -1


def _pad(dataset_indices, chunk_size, placeholder_dataset_id):
  """Appends a placeholder entry so the run lengths add up to `chunk_size`.

  The list is modified in place: exactly one
  `[placeholder_dataset_id, pad_size]` entry is appended, even when `pad_size`
  is 0. Nothing is returned.

  Args:
    dataset_indices: list of (dataset_id, num_repeats) pairs, i.e. a run-length
      encoded sequence of dataset IDs.
    chunk_size: int, total number of elements the padded sequence must encode.
    placeholder_dataset_id: int, dataset ID used for the padding entry.
  """
  num_encoded = 0
  for _, num_repeats in dataset_indices:
    num_encoded += num_repeats
  num_missing = chunk_size - num_encoded
  # A chunk must never already encode more elements than its prescribed size.
  assert num_missing >= 0
  dataset_indices.append([placeholder_dataset_id, num_missing])


def episode_representation_generator(dataset_spec, split, pool, sampler):
  """Yields run-length encoded dataset ID sequences, one per episode.

  Every episode is laid out as a sequence of fixed-size chunks:

  * first a "flush" chunk, used to discard the examples left over in a class
    when the episode would otherwise straddle the end of a pass over that class
    (this avoids accidentally repeating an example across the boundary), then
  * one chunk per remaining entry of `sampler.compute_chunk_sizes()` (for
    example a "support" chunk followed by a "query" chunk).

  Each chunk is padded to its prescribed size with the placeholder dataset ID,
  whose value is `num_classes`, so that downstream code can locate chunk and
  episode boundaries in the flat stream. A padding entry is always emitted, with
  a repeat count of 0 when no padding is needed.

  The stream is not materialized element by element. Instead each episode is a
  list of `(dataset_id, num_repeats)` rows, meaning "take `num_repeats` examples
  from the Dataset with index `dataset_id`".

  The output is meant to drive `tf.data.experimental.choose_from_datasets` over
  a list made of one Dataset per class of `split` (there are `num_classes` of
  them), followed by a placeholder Dataset at index `num_classes` that outputs a
  constant `(b'', PLACEHOLDER_CLASS_ID)` tuple.

  A dataset ID is not the (absolute) class ID: the dataset ID is the position of
  a Dataset in that list, whereas the class ID (or label) is the second element
  of the tuples that Dataset outputs.

  A sampled episode description is reused for
  `sampler.episode_description_switch_frequency` consecutive episodes before a
  new one is sampled (values below 1 behave like 1).

  Args:
    dataset_spec: DatasetSpecification, dataset specification.
    split: one of Split.TRAIN, Split.VALID, or Split.TEST.
    pool: A string ('train' or 'test') or None, indicating which example-level
      split to select, if the current dataset has them.
    sampler: EpisodeDescriptionSampler instance.

  Yields:
    episode_representation: int64 numpy array of shape [N, 2], where N varies
      between episodes. Column 0 holds dataset IDs, column 1 repeat counts.

  Raises:
    ValueError: if an episode description requests more examples of a class
      than the class contains in total.
  """
  all_chunk_sizes = sampler.compute_chunk_sizes()
  # The first size is for the "flush" chunk, the others for the chunks that
  # carry the actual episode content.
  flush_chunk_size = all_chunk_sizes[0]
  other_chunk_sizes = all_chunk_sizes[1:]

  class_set = dataset_spec.get_classes(split)
  num_classes = len(class_set)
  placeholder_dataset_id = num_classes

  # Number of examples available in each class, keyed by dataset ID.
  images_in_class = {}
  for class_idx in range(num_classes):
    images_in_class[class_idx] = dataset_spec.get_total_images_per_class(
        class_set[class_idx], pool)
  # Number of examples already consumed in the current pass over each class.
  cursors = [0 for _ in range(num_classes)]

  # Number of episodes already produced with the current description; 0 means
  # a new description has to be sampled.
  episodes_with_description = 0
  # Infinite loop over episodes.
  while True:
    flush_chunk = []
    content_chunks = [[] for _ in other_chunk_sizes]

    if not episodes_with_description:
      # A description is a tuple of `(class_idx, ...)` tuples: `class_idx` is
      # the class to sample from, and the remaining `len(chunk_sizes) - 1`
      # elements say how many examples go to each chunk.
      episode_description = sampler.sample_episode_description()

    episodes_with_description += 1
    if (episodes_with_description >=
        sampler.episode_description_switch_frequency):
      episodes_with_description = 0

    for entry in episode_description:
      class_idx = entry[0]
      per_chunk_counts = entry[1:]
      total_requested = sum(per_chunk_counts)
      available = images_in_class[class_idx]
      if total_requested > available:
        raise ValueError("Requesting more images than what's available for the "
                         'whole class')
      # When the current pass over the class cannot satisfy the request, the
      # leftover examples are sent to the flush chunk and a new pass starts.
      # TODO(lamblinp): factor this out into its own tracker class for
      # readability and testability.
      leftover = available - cursors[class_idx]
      if leftover < total_requested:
        flush_chunk.append([class_idx, leftover])
        cursors[class_idx] = 0
      # E.g. in a few-shot learning context, `per_chunk_counts = [5, 8]` sends
      # 5 examples to the "support" chunk and 8 to the "query" chunk.
      for chunk, num_to_allocate in zip(content_chunks, per_chunk_counts):
        chunk.append([class_idx, num_to_allocate])
      cursors[class_idx] += total_requested

    # Bring every chunk to its agreed-upon size with placeholder dataset IDs.
    _pad(flush_chunk, flush_chunk_size, placeholder_dataset_id)
    for chunk, chunk_size in zip(content_chunks, other_chunk_sizes):
      _pad(chunk, chunk_size, placeholder_dataset_id)

    # Flush chunk first, then the content chunks in order.
    rows = list(flush_chunk)
    for chunk in content_chunks:
      rows.extend(chunk)
    yield np.array(rows, dtype='int64')


def decompress_episode_representation(episode_representation):
  """Expands a run-length encoded episode into a flat dataset ID sequence.

  Args:
    episode_representation: tensor of shape [None, 2]. Column 0 holds dataset
      IDs and column 1 holds how many times each of them is repeated in the
      sequence.

  Returns:
    1D tensor, the dataset IDs with each one repeated as many times as
    requested.
  """
  # Pin the static shape so the tensor can be split into exactly two columns.
  episode_representation.set_shape([None, 2])
  columns = tf.unstack(episode_representation, axis=1)
  dataset_ids, repeats = columns
  return tf.repeat(dataset_ids, repeats)


class Reader(object):
  """Source-specific front end of the tf.data input pipeline.

  A Reader is bound to one source (one split of one dataset). It builds, for
  each class of that split, a tf.data.Dataset that reads the class's examples
  from a TFRecord file and pairs them with the class ID.
  """

  def __init__(self,
               dataset_spec,
               split,
               shuffle_buffer_size,
               read_buffer_size_bytes,
               num_prefetch,
               num_to_take=-1,
               num_unique_descriptions=0):
    """Binds the Reader to the source identified by dataset_spec and split.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: A learning_spec.Split object identifying the source split.
      shuffle_buffer_size: An integer, the shuffle buffer size for each Dataset
        object. If 0, no shuffling operation will happen.
      read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
      num_prefetch: int, the number of examples to prefetch for each class of
        each dataset. Prefetching occurs just after the class-specific Dataset
        object is constructed. If < 1, no prefetching occurs.
      num_to_take: Optional, an int specifying a number of elements to pick from
        each tfrecord. If specified, the available images of each class will be
        restricted to that int. By default (-1) no restriction is applied and
        all data is used.
      num_unique_descriptions: An integer, the number of unique episode
        descriptions to use. If set to x > 0, x episode descriptions are
        pre-generated, and repeatedly iterated over. This is especially helpful
        when running on TPUs as it avoids the use of
        tf.data.Dataset.from_generator. If set to x = 0, no such upper bound on
        number of unique episode descriptions is set.
    """
    # What to read.
    self.dataset_spec = dataset_spec
    self.split = split
    self.base_path = dataset_spec.path
    self.class_set = dataset_spec.get_classes(split)
    self.num_classes = len(self.class_set)

    # How to read it.
    self.shuffle_buffer_size = shuffle_buffer_size
    self.read_buffer_size_bytes = read_buffer_size_bytes
    self.num_prefetch = num_prefetch
    self.num_to_take = num_to_take
    self.num_unique_descriptions = num_unique_descriptions

  def construct_class_datasets(self,
                               pool=None,
                               repeat=True,
                               shuffle=True,
                               shuffle_seed=None):
    """Builds one tf.data.Dataset per class of the split.

    Each returned dataset outputs `(example_string, class_id)` tuples, where
    `class_id` is the absolute class ID taken from `self.class_set`.

    Args:
      pool: A string (optional) indicating whether to only read examples from a
        given example-level split.
      repeat: Boolean indicating whether each of the class datasets should be
        repeated (to provide an infinite stream) or not.
      shuffle: Boolean indicating whether each of the class datasets should be
        shuffled or not.
      shuffle_seed: Optional, an int containing the seed passed to
        tf.data.Dataset.shuffle.

    Returns:
      class_datasets: list of tf.data.Dataset, one for each class.

    Raises:
      NotImplementedError: if `pool` is given but pools are not supported, or if
        the file pattern of the dataset specification describes sharded files.
      ValueError: if the file pattern does not start with "{}".
    """
    file_pattern = self.dataset_spec.file_pattern
    class_datasets = []
    for dataset_id in range(self.num_classes):
      class_id = self.class_set[dataset_id]

      # Locate the TFRecord file that holds the examples of this class.
      if pool:
        # NOTE(review): when `pool` is set and data.POOL_SUPPORTED is true, no
        # filename is computed here, so the TFRecordDataset construction below
        # fails on an unassigned name. Confirm how pools should resolve files.
        if not data.POOL_SUPPORTED:
          raise NotImplementedError(
              'Example-level splits or pools not supported.')
      elif file_pattern.startswith('{}_{}'):
        # TODO(lamblinp): Add support for sharded files if needed.
        raise NotImplementedError('Sharded files are not supported yet. '
                                  'The code expects one dataset per class.')
      elif file_pattern.startswith('{}'):
        filename = os.path.join(self.base_path, file_pattern.format(class_id))
      else:
        raise ValueError('Unsupported file_pattern in DatasetSpec: %s. '
                         'Expected something starting with "{}" or "{}_{}".' %
                         file_pattern)

      examples = tf.data.TFRecordDataset(
          filename, buffer_size=self.read_buffer_size_bytes)

      # Keep only the first `num_to_take` records; the default of -1 keeps them
      # all.
      examples = examples.take(self.num_to_take)

      if self.num_prefetch > 0:
        examples = examples.prefetch(self.num_prefetch)

      if shuffle:
        # A buffer larger than the class itself would only waste memory, so the
        # buffer size is capped at the number of examples in the class.
        class_size = self.dataset_spec.get_total_images_per_class(
            class_id, pool=pool)
        buffer_size = min(class_size, self.shuffle_buffer_size)
        # A buffer of size 0 or 1 would not shuffle anything.
        if buffer_size > 1:
          examples = examples.shuffle(
              buffer_size=buffer_size,
              seed=shuffle_seed,
              reshuffle_each_iteration=True)

      if repeat:
        examples = examples.repeat()

      # The labels are absolute, dataset-specific class IDs (not relative to a
      # given split). Collisions across datasets are fine, since episodes are
      # never sampled across several datasets.
      labels = tf.data.Dataset.from_tensors(class_id).repeat()
      class_datasets.append(tf.data.Dataset.zip((examples, labels)))

    assert len(class_datasets) == self.num_classes
    return class_datasets


class EpisodeReaderMixin(object):
  """Mixin that assembles the per-class datasets into a stream of episodes.

  The host class is expected to provide `construct_class_datasets` as well as
  the `dataset_spec`, `split`, `shuffle_buffer_size` and
  `num_unique_descriptions` attributes.
  """

  def create_dataset_input_pipeline(self,
                                    sampler,
                                    pool=None,
                                    shuffle_seed=None):
    """Creates a Dataset encapsulating the input pipeline for one data source.

    Args:
      sampler: EpisodeDescriptionSampler instance.
      pool: A string (optional) indicating whether to only read examples from a
        given example-level split.
      shuffle_seed: Optional, an int containing the seed passed to
        tf.data.Dataset.shuffle.

    Returns:
      dataset: a tf.data.Dataset instance which encapsulates episode creation
        for the data identified by `dataset_spec` and `split`. These episodes
        contain flushed examples and are internally padded with placeholders.
        A later part of the pipeline, shared across all sources, will extract
        support and query sets and decode the example strings.
    """
    # Class datasets are shuffled unless self.shuffle_buffer_size is 0.
    shuffle = (self.shuffle_buffer_size and self.shuffle_buffer_size > 0)
    class_datasets = self.construct_class_datasets(
        pool=pool, shuffle=shuffle, shuffle_seed=shuffle_seed)

    # The last dataset of the list is a placeholder one, which endlessly outputs
    # `(b'', PLACEHOLDER_CLASS_ID)` tuples.
    empty_examples = tf.data.Dataset.from_tensors(b'').repeat()
    placeholder_labels = tf.data.Dataset.from_tensors(
        PLACEHOLDER_CLASS_ID).repeat()
    class_datasets.append(
        tf.data.Dataset.zip((empty_examples, placeholder_labels)))

    # The "choice" dataset is a stream of dataset IDs telling which class
    # dataset the next example comes from. `choose_from_datasets` then turns
    # that stream into a stream of `(example_string, class_id)` tuples.
    representation_generator = functools.partial(
        episode_representation_generator,
        dataset_spec=self.dataset_spec,
        split=self.split,
        pool=pool,
        sampler=sampler)

    if self.num_unique_descriptions:
      # With num_unique_descriptions = x > 0, x episodes are generated up front
      # and iterated over repeatedly. They are stored as serialized strings,
      # because ragged arrays cannot be shuffled with tf.data.Dataset directly.
      serialized_representations = [
          tf.io.serialize_tensor(representation)
          for representation in itertools.islice(representation_generator(),
                                                 self.num_unique_descriptions)
      ]
      choice_dataset = tf.data.Dataset.from_tensor_slices(
          serialized_representations)
      choice_dataset = choice_dataset.shuffle(self.num_unique_descriptions)
      choice_dataset = choice_dataset.map(
          lambda serialized: tf.io.parse_tensor(serialized, tf.int64))
      choice_dataset = choice_dataset.repeat()
    else:
      choice_dataset = tf.data.Dataset.from_generator(
          representation_generator, tf.int64, tf.TensorShape([None, 2]))

    # Expand each compact representation, then flatten the episodes into a
    # single stream of dataset IDs.
    choice_dataset = choice_dataset.map(decompress_episode_representation)
    choice_dataset = choice_dataset.unbatch()

    dataset = tf.data.experimental.choose_from_datasets(class_datasets,
                                                        choice_dataset)

    # Episodes have a fixed size prescribed by `sampler.compute_chunk_sizes`.
    episode_size = sum(sampler.compute_chunk_sizes())
    dataset = dataset.batch(episode_size)
    # Overlap batching and episode processing.
    return dataset.prefetch(1)


class EpisodeReader(Reader, EpisodeReaderMixin):
  """Subclass of Reader assembling the examples as Episodes.

  It only combines `Reader` with `EpisodeReaderMixin`, which supplies
  `create_dataset_input_pipeline`; it defines no members of its own.
  """


def add_offset_to_target(example_strings, targets, offset):
  """Shifts the targets by `offset`, passing the examples through unchanged.

  This function is intended to be passed to tf.data.Dataset.map.

  Args:
    example_strings: 1-D Tensor of dtype str, Example protocol buffers.
    targets: 1-D Tensor of dtype int, targets representing the absolute class
      IDs.
    offset: int, number to add to class IDs to get targets.

  Returns:
    example_strings, labels: Tensors, a batch of examples and labels.
  """
  return example_strings, targets + offset


class BatchReaderMixin(object):
  """Mixin class to assemble examples as batches.

  The host class is expected to provide `construct_class_datasets` as well as
  the `dataset_spec`, `class_set` and `shuffle_buffer_size` attributes.
  """

  def create_dataset_input_pipeline(self,
                                    batch_size,
                                    offset=0,
                                    pool=None,
                                    shuffle_seed=None):
    """Creates a Dataset encapsulating the input pipeline for one data source.

    Args:
      batch_size: An int representing the max number of examples in each batch.
        The last batch of a pass over the data may be smaller.
      offset: An int, that is added to the value of all the targets. This makes
        it possible to have a unique range of targets for each dataset.
      pool: A string (optional) indicating whether to only read examples from a
        given example-level split. If it is provided, these examples will be
        used as 'real test data', and used once each for evaluation only. The
        accepted values are 'valid' and 'test'.
      shuffle_seed: Optional, an int containing the seed passed to
        tf.data.Dataset.shuffle.

    Returns:
      dataset: a tf.data.Dataset instance which encapsulates batch creation for
        the data identified by `dataset_spec` and `split`. These batches contain
        compressed image representations and (possibly offset) absolute class
        IDs. A later part of the pipeline, shared across all sources, will
        decode the example strings.

    Raises:
      ValueError: Invalid pool provided. The supported values are 'valid' and
        'test'.
      NotImplementedError: if `pool` is provided but pools are not supported, or
        if the class set of the split is not contiguous.
    """
    if pool and pool not in ['valid', 'test']:
      raise ValueError('Invalid pool provided. The supported values '
                       'are "valid" and "test".')
    # Do not shuffle or repeat each class dataset, to avoid fuzzing epoch
    # boundaries.
    class_datasets = self.construct_class_datasets(
        pool=pool, repeat=False, shuffle=False)
    num_classes = len(class_datasets)

    if pool:
      # NOTE(review): when pools are supported, this branch leaves `dataset`
      # unassigned, so the batching below would fail. Confirm the intended
      # pipeline for example-level splits.
      if not data.POOL_SUPPORTED:
        raise NotImplementedError(
            'Example-level splits or pools not supported.')
    else:
      # To have labels start at 0 and be contiguous, subtract the starting
      # index from all of them.
      start_ind = self.class_set[0]
      class_set = [
          self.class_set[ds_id] - start_ind for ds_id in range(num_classes)
      ]
      if list(class_set) != list(range(num_classes)):
        raise NotImplementedError('Batch training currently assumes the class '
                                  'set is contiguous and starts at 0.')

      # Sample from each class dataset according to its proportion of examples,
      # so examples from one class should be spread across the whole epoch.
      # Then, shuffle and repeat the combined dataset.
      # NOTE(review): the counts are looked up with the start-relative class
      # IDs, while `construct_class_datasets` uses the absolute IDs from
      # `self.class_set`. These only agree when the class set starts at 0 -
      # confirm this is intended for other splits.
      num_examples_per_class = [
          self.dataset_spec.get_total_images_per_class(class_id, pool=pool)
          for class_id in class_set
      ]
      num_examples_per_class = np.array(num_examples_per_class, 'float64')
      class_proportions = num_examples_per_class / num_examples_per_class.sum()

      # Explicitly skip datasets with a weight of 0, as sample_from_datasets
      # can have some trouble with them.
      new_datasets_and_weights = [
          (dataset, weight)
          for (dataset, weight) in zip(class_datasets, class_proportions)
          if weight > 0
      ]
      class_datasets, class_proportions = zip(*new_datasets_and_weights)
      dataset = tf.data.experimental.sample_from_datasets(
          class_datasets, weights=class_proportions, seed=shuffle_seed)
      if self.shuffle_buffer_size and self.shuffle_buffer_size > 0:
        dataset = dataset.shuffle(
            buffer_size=self.shuffle_buffer_size,
            seed=shuffle_seed,
            reshuffle_each_iteration=True)

    # Using drop_remainder=False for two reasons:
    # - Most importantly, during established splits evaluation, we need to
    #   evaluate on all examples.
    # - Also during training, if the shuffle buffer does not hold all the data,
    #   the last examples are more likely to be dropped than the first ones.
    # In any case, we are handling variable-sized batches just fine, so there
    # is no real reason to drop data.
    dataset = dataset.batch(batch_size, drop_remainder=False)
    if not pool:
      dataset = dataset.repeat()

    if offset:
      map_fn = functools.partial(add_offset_to_target, offset=offset)
      dataset = dataset.map(map_fn)

    # Overlap batching and episode processing.
    dataset = dataset.prefetch(1)

    return dataset


class BatchReader(Reader, BatchReaderMixin):
  """Subclass of Reader assembling the examples as Batches.

  It only combines `Reader` with `BatchReaderMixin`, which supplies
  `create_dataset_input_pipeline`; it defines no members of its own.
  """
