# coding=utf-8
# Copyright 2022 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Sampling the composition of episodes.

The composition of episodes consists in the number of classes (num_ways), which
classes (relative class_ids), and how many examples per class (num_support,
num_query).

This module aims at replacing `sampler.py` in the new data pipeline.
"""
# TODO(lamblinp): Update variable names to be more consistent
# - target, class_idx, label

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from meta_dataset.data import dataset_spec as dataset_spec_lib
from meta_dataset.data import imagenet_specification
import numpy as np
from six.moves import zip

# Module-level random number generator. It is created with `seed=None`, i.e.
# initialized randomly, and can be seeded. The sampling functions in this
# module fall back to it when they are called without an `rng` argument, and
# every `EpisodeDescriptionSampler` draws the seed of its own generator from it.
RNG = np.random.RandomState(seed=None)

# How the value of MAX_SPANNING_LEAVES_ELIGIBLE was selected.
# This controls the upper bound on the number of leaves that an internal node
# may span in order for it to be eligible for selection. We found that this
# value is the minimum such value that allows every leaf to be reachable. By
# decreasing it, not all leaves would be reachable (therefore some classes would
# never be used). By increasing it, all leaves would still be reachable but we
# would sacrifice naturalness more than necessary (since when we sample an
# internal node that has more than MAX_HIERARCHICAL_CLASSES spanned leaves we
# sub-sample those leaves randomly which is essentially performing class
# selection without taking the hierarchy into account).
# NOTE(review): `MAX_HIERARCHICAL_CLASSES` is not defined in this module; the
# sub-sampling in `EpisodeDescriptionSampler.sample_class_ids` is bounded by
# the sampler's `max_ways_upper_bound`, presumably the same quantity -- confirm.
MAX_SPANNING_LEAVES_ELIGIBLE = 392


def sample_num_ways_uniformly(num_classes, min_ways, max_ways, rng=None):
  """Draws the number of ways of an episode uniformly at random.

  The draw is made over the integers from `min_ways` up to and including the
  smaller of `max_ways` and `num_classes`.

  Args:
    num_classes: int, number of classes available.
    min_ways: int, smallest number of ways that may be drawn.
    max_ways: int, largest number of ways that may be drawn; it only has an
      effect when it is smaller than `num_classes`.
    rng: np.random.RandomState used for the draw. The module-level `RNG` is
      used when it is not provided.

  Returns:
    num_ways: int, number of ways for the episode.
  """
  generator = rng or RNG
  # Never ask for more ways than there are classes.
  upper_bound = num_classes if num_classes < max_ways else max_ways
  # `randint` excludes `high`, hence the `+ 1` to make the bound inclusive.
  return generator.randint(low=min_ways, high=upper_bound + 1)


def sample_class_ids_uniformly(num_ways, rel_classes, rng=None):
  """Picks `num_ways` distinct (relative) class IDs for an episode.

  Args:
    num_ways: int, how many class IDs to pick.
    rel_classes: list of int, the class IDs that may be picked. It is handed
      to `rng.choice` unchanged.
    rng: np.random.RandomState used for the draw. The module-level `RNG` is
      used when it is not provided.

  Returns:
    class_ids: np.array, class IDs for the episode, with values in rel_classes.
  """
  generator = rng or RNG
  # Sampling without replacement guarantees that no class appears twice.
  class_ids = generator.choice(rel_classes, num_ways, replace=False)
  return class_ids


def compute_num_query(images_per_class, max_num_query, num_support):
  """Computes the number of query examples per class in the episode.

  Query sets are balanced, i.e., contain the same number of examples for each
  class in the episode.

  The number of query examples satisfies the following conditions:
  - it is no greater than `max_num_query`
  - if support size is unspecified, it is at most half the size of the
    smallest class in the episode
  - if support size is specified, it is at most the size of the smallest class
    in the episode minus the max support size.

  Args:
    images_per_class: np.array, number of images for each class.
    max_num_query: int, upper bound on the number of query examples per class.
    num_support: None, int (Python or NumPy integer) or tuple(int, int), number
      (or range) of support images per class. When a range is given, only its
      upper end is used here.

  Returns:
    num_query: int, number of query examples per class in the episode.

  Raises:
    ValueError: if the smallest class has fewer than 2 images (support size
      unspecified), or does not have at least one image left once the maximum
      support size has been set aside (support size specified).
  """
  if num_support is None:
    if images_per_class.min() < 2:
      raise ValueError('Expected at least 2 images per class.')
    return np.minimum(max_num_query, (images_per_class // 2).min())

  # NumPy integer scalars (e.g. np.int64) are not instances of `int`, so they
  # have to be accepted explicitly; otherwise they would be mistaken for a
  # (min, max) range and fail to unpack.
  if isinstance(num_support, (int, np.integer)):
    max_support = num_support
  else:
    _, max_support = num_support

  # Every class must keep at least one image for the query set.
  if (images_per_class - max_support).min() < 1:
    raise ValueError(
        'Expected at least {} images per class'.format(max_support + 1))
  return np.minimum(max_num_query, images_per_class.min() - max_support)


def sample_support_set_size(num_remaining_per_class,
                            max_support_size_contrib_per_class,
                            max_support_set_size,
                            rng=None):
  """Draws the total size of the support set of an episode.

  The returned size obeys the following constraints:

  * No class contributes more than `max_support_size_contrib_per_class` to it.
  * It does not exceed `max_support_set_size`.
  * It is at least the number of ways, i.e. `len(num_remaining_per_class)`.

  Args:
    num_remaining_per_class: np.array, number of images still available for
      each class once the query images have been set aside.
    max_support_size_contrib_per_class: int, cap on what a single class may
      contribute to the support set _size_. This does not limit how many
      examples of that class end up in the support set; it only limits the
      class's share in the computation of the size.
    max_support_set_size: int, largest allowed support set size.
    rng: np.random.RandomState used for the draw. The module-level `RNG` is
      used when it is not provided.

  Returns:
    support_set_size: size of the support set in the episode. The value is
      integral but is returned as a NumPy floating scalar, since it is the sum
      of `np.floor` outputs.

  Raises:
    ValueError: if `max_support_set_size` is smaller than the number of ways.
  """
  generator = rng or RNG
  num_ways = len(num_remaining_per_class)
  if max_support_set_size < num_ways:
    raise ValueError('max_support_set_size is too small to have at least one '
                     'support example per class.')

  # A single scaling factor is shared by all classes of the episode.
  beta = generator.uniform()
  capped_contributions = np.minimum(max_support_size_contrib_per_class,
                                    num_remaining_per_class)
  # floor(x + 1) amounts to drawing beta in (0, 1] and taking ceil(x): every
  # class contributes at least one example, so the total is never smaller than
  # the number of ways.
  per_class_contribution = np.floor(beta * capped_contributions + 1)
  uncapped_size = per_class_contribution.sum()
  return np.minimum(uncapped_size, max_support_set_size)


def sample_num_support_per_class(images_per_class,
                                 num_remaining_per_class,
                                 support_set_size,
                                 min_log_weight,
                                 max_log_weight,
                                 rng=None):
  """Draws how many support examples each class gets.

  The intent is for the support set to loosely follow the class frequencies:
  each class is weighted by its number of images times a random factor whose
  logarithm is uniform in [`min_log_weight`, `max_log_weight`]. The result is
  such that:

  * Every class gets at least one support example.
  * No class gets more support examples than it has remaining examples after
    the query set has been taken into account.
  * No class gets more than `support_set_size` examples.

  Args:
    images_per_class: np.array, number of images for each class.
    num_remaining_per_class: np.array, number of images still available for
      each class once the query images have been set aside.
    support_set_size: int, size of the support set in the episode.
    min_log_weight: float, lower end of the log-weight range.
    max_log_weight: float, upper end of the log-weight range.
    rng: np.random.RandomState used for the draw. The module-level `RNG` is
      used when it is not provided.

  Returns:
    num_support_per_class: np.array, number of support examples for each class.

  Raises:
    ValueError: if `support_set_size` is smaller than the number of ways, or if
      some class has no remaining example.
  """
  generator = rng or RNG
  num_ways = len(num_remaining_per_class)
  if support_set_size < num_ways:
    raise ValueError('Requesting smaller support set than the number of ways.')
  if np.min(num_remaining_per_class) < 1:
    raise ValueError('Some classes have no remaining examples.')

  # One example per class is reserved up front; only the rest is distributed
  # according to the randomized class proportions.
  budget_to_distribute = support_set_size - num_ways

  log_weights = generator.uniform(
      min_log_weight, max_log_weight, size=images_per_class.shape)
  weighted_counts = images_per_class * np.exp(log_weights)
  shares = weighted_counts / weighted_counts.sum()

  # The `+ 1` is the reserved example of each class.
  wanted = np.floor(shares * budget_to_distribute).astype('int32') + 1

  # A class cannot provide more examples than it has left.
  return np.minimum(wanted, num_remaining_per_class)


class EpisodeDescriptionSampler(object):
  """Generates descriptions of Episode composition.

  In particular, for each Episode, it will generate the class IDs (relative to
  the selected split of the dataset) to include, as well as the number of
  support and query examples for each class ID.
  """

  def __init__(self,
               dataset_spec,
               split,
               episode_descr_config,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False,
               ignore_hierarchy_probability=0.0):
    """Initializes an EpisodeDescriptionSampler.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      episode_descr_config: An instance of EpisodeDescriptionConfig containing
        parameters relating to sampling shots and ways for episodes.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
        Overridden to False when `episode_descr_config.ignore_dag_ontology` is
        set.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes. Overridden to
        False when `episode_descr_config.ignore_bilevel_ontology` is set.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.
      ignore_hierarchy_probability: Float, if using a hierarchy, this flag makes
        the sampler ignore the hierarchy for this proportion of episodes and
        instead sample categories uniformly.

    Raises:
      ValueError: Inconsistent parameters: too few classes left for
        `min_ways`, `use_all_classes` combined with classes being filtered out,
        a hierarchy option combined with incompatible settings or with a
        dataset specification of the wrong type, or no eligible internal node
        in the DAG.
    """
    # Each instance has its own RNG which is seeded from the module-level RNG,
    # which makes episode description sampling deterministic within individual
    # data sources.
    self._rng = np.random.RandomState(
        seed=RNG.randint(0, 2**32, size=None, dtype='uint32'))
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.ignore_hierarchy_probability = ignore_hierarchy_probability
    self.use_all_classes = use_all_classes
    self.num_ways = episode_descr_config.num_ways
    self.num_support = episode_descr_config.num_support
    self.num_query = episode_descr_config.num_query
    self.min_ways = episode_descr_config.min_ways
    self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
    self.max_num_query = episode_descr_config.max_num_query
    self.max_support_set_size = episode_descr_config.max_support_set_size
    self.max_support_size_contrib_per_class = (
        episode_descr_config.max_support_size_contrib_per_class)
    self.min_log_weight = episode_descr_config.min_log_weight
    self.max_log_weight = episode_descr_config.max_log_weight
    self.min_examples_in_class = episode_descr_config.min_examples_in_class
    self.episode_description_switch_frequency = (
        episode_descr_config.episode_description_switch_frequency)

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)
    # Filter out classes with too few examples
    self._filtered_class_set = []
    # Store (class_id, n_examples) of skipped classes for logging.
    skipped_classes = []
    for class_id in self.class_set:
      n_examples = dataset_spec.get_total_images_per_class(class_id, pool=pool)
      if n_examples < self.min_examples_in_class:
        skipped_classes.append((class_id, n_examples))
      else:
        self._filtered_class_set.append(class_id)
    self.num_filtered_classes = len(self._filtered_class_set)

    if skipped_classes:
      logging.info(
          'Skipping the following classes, which do not have at least '
          '%d examples', self.min_examples_in_class)
    for class_id, n_examples in skipped_classes:
      logging.info('%s (ID=%d, %d examples)',
                   dataset_spec.class_names[class_id], class_id, n_examples)

    if self.min_ways and self.num_filtered_classes < self.min_ways:
      raise ValueError(
          '"min_ways" is set to {}, but split {} of dataset {} only has {} '
          'classes with at least {} examples ({} total), so it is not possible '
          'to create an episode for it. This may have resulted from applying a '
          'restriction on this split of this dataset by specifying '
          'benchmark.restrict_classes or benchmark.min_examples_in_class.'
          .format(self.min_ways, split, dataset_spec.name,
                  self.num_filtered_classes, self.min_examples_in_class,
                  self.num_classes))

    if self.use_all_classes:
      if self.num_classes != self.num_filtered_classes:
        raise ValueError('"use_all_classes" is not compatible with a value of '
                         '"min_examples_in_class" ({}) that results in some '
                         'classes being excluded.'.format(
                             self.min_examples_in_class))
      self.num_ways = self.num_classes

    # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
    if episode_descr_config.ignore_dag_ontology:
      self.use_dag_hierarchy = False
    if episode_descr_config.ignore_bilevel_ontology:
      self.use_bilevel_hierarchy = False

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')
      if self.min_examples_in_class > 0:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"min_examples_in_class".')
      if self.use_dag_hierarchy:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"use_dag_hierarchy".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The id's of the superclasses of the split (a contiguous range of ints).
      all_superclasses = dataset_spec.get_superclasses(self.split)
      self.superclass_set = []
      for i in all_superclasses:
        # Every superclass must be able to provide at least `min_ways` classes,
        # since an episode draws all of its classes from a single superclass.
        if self.dataset_spec.classes_per_superclass[i] < self.min_ways:
          raise ValueError(
              'Superclass: %d has num_classes=%d < min_ways=%d.' %
              (i, self.dataset_spec.classes_per_superclass[i], self.min_ways))
        self.superclass_set.append(i)
    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split.
      class_set = self.class_set
      abs_to_rel_ids = dict((abs_id, i) for i, abs_id in enumerate(class_set))

      # Set view of the classes that passed the `min_examples_in_class` filter,
      # so that the membership test in the nested loop below is O(1) instead of
      # a scan of the list.
      filtered_class_ids = set(self._filtered_class_set)

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node. We ensure a deterministic order
      # by sorting the inner-nodes and their corresponding leaves by wn_id.
      self.span_leaves_rel = []
      for node in sorted(internal_nodes, key=lambda n: n.wn_id):
        node_leaves = sorted(spanning_leaves_dict[node], key=lambda n: n.wn_id)
        # Build a list of relative class IDs of leaves that have at least
        # min_examples_in_class examples.
        ids_rel = []
        for leaf in node_leaves:
          abs_id = dataset_spec.class_names_to_ids[leaf.wn_id]
          if abs_id in filtered_class_ids:
            ids_rel.append(abs_to_rel_ids[abs_id])

        # Internal nodes are eligible if they span at least `self.min_ways`
        # and at most `MAX_SPANNING_LEAVES_ELIGIBLE` (filtered) leaves.
        if self.min_ways <= len(ids_rel) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          self.span_leaves_rel.append(ids_rel)

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, or '
                         'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')

  def sample_class_ids(self):
    """Returns the (relative) class IDs for an episode.

    If self.use_dag_hierarchy, it samples them according to a procedure
    informed by the dataset's ontology, otherwise randomly.
    If self.use_bilevel_hierarchy, all classes are drawn from one superclass
    picked uniformly at random.
    In both cases the hierarchy is ignored with probability
    self.ignore_hierarchy_probability.
    If self.min_examples_in_class > 0, classes with too few examples will not
    be selected.

    Returns:
      A sequence (list or np.array) of class IDs relative to the split.
    """
    # Probabilities of [using, ignoring] the hierarchy for this episode.
    prob = [1.0, 0.0]
    if self.ignore_hierarchy_probability:
      prob = [
          1.0 - self.ignore_hierarchy_probability,
          self.ignore_hierarchy_probability
      ]

    if self.use_dag_hierarchy and self._rng.choice([True, False], p=prob):
      # Retrieve the list of relative class IDs for an internal node sampled
      # uniformly at random. An index is drawn instead of handing the nested
      # list to `choice`: `choice` first converts its argument to an array and
      # rejects anything that is not 1-dimensional, which a list of lists is
      # not (or cannot even be converted when the inner lists are ragged).
      node_index = self._rng.randint(0, len(self.span_leaves_rel))
      episode_classes_rel = self.span_leaves_rel[node_index]

      # If the number of chosen classes is larger than desired, sub-sample them.
      if len(episode_classes_rel) > self.max_ways_upper_bound:
        episode_classes_rel = self._rng.choice(
            episode_classes_rel,
            size=[self.max_ways_upper_bound],
            replace=False)

      # Light check to make sure the chosen number of classes is valid.
      assert len(episode_classes_rel) >= self.min_ways
      assert len(episode_classes_rel) <= self.max_ways_upper_bound
    elif self.use_bilevel_hierarchy and self._rng.choice([True, False], p=prob):
      # First sample a coarse category uniformly. Then randomly sample the way
      # uniformly, but taking care not to sample more than the number of classes
      # of the chosen supercategory.
      episode_superclass = self._rng.choice(self.superclass_set, 1)[0]
      num_superclass_classes = self.dataset_spec.classes_per_superclass[
          episode_superclass]

      num_ways = sample_num_ways_uniformly(
          num_superclass_classes,
          min_ways=self.min_ways,
          max_ways=self.max_ways_upper_bound,
          rng=self._rng)

      # e.g. if these are [3, 1] then the 4'th and the 2'nd of the subclasses
      # that belong to the chosen superclass will be used. If the class id's
      # that belong to this superclass are [23, 24, 25, 26] then the returned
      # episode_classes_rel will be [26, 24] which as usual are number relative
      # to the split.
      episode_subclass_ids = sample_class_ids_uniformly(
          num_ways, num_superclass_classes, rng=self._rng)
      (episode_classes_rel,
       _) = self.dataset_spec.get_class_ids_from_superclass_subclass_inds(
           self.split, episode_superclass, episode_subclass_ids)
    elif self.use_all_classes:
      episode_classes_rel = np.arange(self.num_classes)
    else:  # No type of hierarchy is used. Classes are randomly sampled.
      if self.num_ways is not None:
        num_ways = self.num_ways
      else:
        num_ways = sample_num_ways_uniformly(
            self.num_filtered_classes,
            min_ways=self.min_ways,
            max_ways=self.max_ways_upper_bound,
            rng=self._rng)
      # Filtered class IDs relative to the selected split
      ids_rel = [
          class_id - self.class_set[0] for class_id in self._filtered_class_set
      ]
      episode_classes_rel = sample_class_ids_uniformly(
          num_ways, ids_rel, rng=self._rng)

    return episode_classes_rel

  def sample_episode_description(self):
    """Returns the composition of an episode.

    Returns:
      A sequence of `(class_id, num_support, num_query)` tuples, where
        relative `class_id` is an integer in [0, self.num_classes).

    Raises:
      ValueError: if a fixed number (or range) of support examples was
        requested and some sampled class does not have enough examples for
        both the support and the query set.
    """
    class_ids = self.sample_class_ids()
    images_per_class = np.array([
        self.dataset_spec.get_total_images_per_class(
            self.class_set[cid], pool=self.pool) for cid in class_ids
    ])

    if self.num_query is not None:
      num_query = self.num_query
    else:
      num_query = compute_num_query(
          images_per_class,
          max_num_query=self.max_num_query,
          num_support=self.num_support)

    if self.num_support is not None:
      # NumPy integer scalars are accepted alongside Python ints; they are not
      # instances of `int` and would otherwise be mistaken for a range.
      if isinstance(self.num_support, (int, np.integer)):
        if any(self.num_support + num_query > images_per_class):
          raise ValueError('Some classes do not have enough examples.')
        num_support = self.num_support
      else:
        start, end = self.num_support
        if any(end + num_query > images_per_class):
          raise ValueError('The range provided for uniform sampling of the '
                           'number of support examples per class is not valid: '
                           'some classes do not have enough examples.')
        # `randint` excludes `high`, hence `end + 1` to make `end` reachable.
        num_support = self._rng.randint(low=start, high=end + 1)
      # The same number of support examples is used for every class.
      num_support_per_class = [num_support for _ in class_ids]
    else:
      num_remaining_per_class = images_per_class - num_query
      support_set_size = sample_support_set_size(
          num_remaining_per_class,
          self.max_support_size_contrib_per_class,
          max_support_set_size=self.max_support_set_size,
          rng=self._rng)
      num_support_per_class = sample_num_support_per_class(
          images_per_class,
          num_remaining_per_class,
          support_set_size,
          min_log_weight=self.min_log_weight,
          max_log_weight=self.max_log_weight,
          rng=self._rng)

    return tuple(
        (class_id, num_support, num_query)
        for class_id, num_support in zip(class_ids, num_support_per_class))

  def compute_chunk_sizes(self):
    """Computes the maximal sizes for the flush, support, and query chunks.

    Sequences of dataset IDs are padded with placeholder IDs to make sure they
    can be batched into episodes of equal sizes.

    The "flush" part of the sequence has a size that is upper-bounded by the
    size of the "support" and "query" parts.

    If variable, the size of the "support" part is in the worst case

        max_support_set_size,

    and the size of the "query" part is in the worst case

        max_ways_upper_bound * max_num_query.

    Returns:
      The sizes of the flush, support, and query chunks.
    """
    if self.num_ways is None:
      max_num_ways = self.max_ways_upper_bound
    else:
      max_num_ways = self.num_ways

    if self.num_support is None:
      support_chunk_size = self.max_support_set_size
    elif isinstance(self.num_support, (int, np.integer)):
      support_chunk_size = max_num_ways * self.num_support
    else:
      # `num_support` is a (min, max) range: size for its upper end.
      largest_num_support_per_class = self.num_support[1]
      support_chunk_size = max_num_ways * largest_num_support_per_class

    if self.num_query is None:
      max_num_query = self.max_num_query
    else:
      max_num_query = self.num_query
    query_chunk_size = max_num_ways * max_num_query

    flush_chunk_size = support_chunk_size + query_chunk_size
    return (flush_chunk_size, support_chunk_size, query_chunk_size)
