# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data structures bundling inputs to the dataset generation process."""

import collections
import dataclasses
import enum
import time
from typing import (Dict, FrozenSet, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypeVar)

import dataclasses_json
import immutabledict
import numpy as np
from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import induction

# Original input vocabularies of SCAN and MiniSCAN.
SCAN_INPUT_VOCABULARY = frozenset([
    'walk', 'look', 'run', 'jump', 'left', 'right', 'turn', 'around',
    'opposite', 'and', 'after', 'twice', 'thrice'
])
MINISCAN_INPUT_VOCABULARY = frozenset(
    ['dax', 'lug', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer', 'zup', 'wif'])
# Extended input vocabulary for SCAN. Contains only additional tokens, i.e., it
# is disjoint from SCAN_INPUT_VOCABULARY.
SCAN_EXTENDED_INPUT_VOCABULARY = frozenset([
    'before', 'following', 'between', 'framing', '2x', '3x', '4x', '5x', '90',
    '180', '270', '360', 'fast', 'cautiously', 'drunkenly', 'zigzag', 'drive',
    'ride', 'fly', 'leap', 'peek', 'up', 'down'
])
# A simple synthetic input vocabulary (the 100 tokens 'i0' through 'i99') for
# use in cases where a larger vocabulary is needed.
INPUT_VOCABULARY_SIZE_100 = frozenset(f'i{num}' for num in range(100))

# Original output vocabularies of SCAN and MiniSCAN.
SCAN_OUTPUT_VOCABULARY = frozenset(
    ['WALK', 'LOOK', 'RUN', 'JUMP', 'LTURN', 'RTURN'])
MINISCAN_OUTPUT_VOCABULARY = frozenset(
    ['RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK', 'BLACK', 'WHITE'])
# Extended output vocabulary for SCAN. Contains only additional tokens, i.e.,
# it is disjoint from SCAN_OUTPUT_VOCABULARY.
SCAN_EXTENDED_OUTPUT_VOCABULARY = frozenset(
    ['UTURN', 'DTURN', 'DRIVE', 'RIDE', 'FLY', 'LEAP', 'PEEK'])
# A simple synthetic output vocabulary (the 100 tokens 'o0' through 'o99') for
# use in cases where a larger vocabulary is needed.
OUTPUT_VOCABULARY_SIZE_100 = frozenset(f'o{num}' for num in range(100))

# Categories available for use in each precedence level. Note that while we may
# allow the number of categories per level to vary from grammar to grammar,
# we keep the relative precedence of the different categories fixed so as to
# provide an unambiguous way of asking the learner whether a given grammar
# rule (referencing some specific categories) is true or not, without
# providing the full grammar to the learner.
#
# Several standard configurations are provided, covering a range of degrees of
# variability around the standard SCAN grammar.
#
# Note also that the exact strings used to represent the syntactic categories
# are unimportant, as long as they are stable. However, in order to ensure that
# the original SCAN grammar (or more precisely, the grammar described by
# scan_finite_nye_standardized.fcfg) is included in the space of grammars
# representable by GrammarSchema, we do include in each standard configuration
# the syntactic categories used in scan_finite_nye_standardized.fcfg for each
# level ('U', 'W', 'D', 'V', 'S', 'C').

# Most basic configuration. Syntactic categories up through level 4 are
# identical to those used in the Nye variant of the original SCAN grammar, and
# with the exception of level 0, only one category is allowed per level.
POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL = immutabledict.immutabledict({
    # Levels used in the original SCAN grammars (with original categories)
    0: ('U', 'W'),
    1: ('D',),
    2: ('V',),
    3: ('S',),
    4: ('C',),
    # Additional levels (beyond those used in the original SCAN grammars)
    5: ('E',),
    6: ('F',),
    7: ('G',),
    8: ('H',),
})

# Configuration allowing for greater variability in categories used in each
# level. Includes the categories from scan_finite_nye_standardized.fcfg for each
# level and as a mnemonic simply adds numerical suffixes to these in cases where
# we need more categories at the same level. Every level offers exactly 8
# categories.
POSSIBLE_CATEGORIES_BY_LEVEL_8_PER_LEVEL = immutabledict.immutabledict({
    # Levels used in the original SCAN grammars
    0: ('U', 'W', 'U1', 'U2', 'U3', 'W1', 'W2', 'W3'),
    1: ('D', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7'),
    2: ('V', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7'),
    3: ('S', 'S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7'),
    4: ('C', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7'),
    # Additional levels (beyond those used in the original SCAN grammars)
    5: ('E', 'E1', 'E2', 'E3', 'E4', 'E5', 'E6', 'E7'),
    6: ('F', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7'),
    # NOTE(review): the last entry is 'G8' whereas every other level ends with
    # suffix 7; this looks like a typo for 'G7'. It is harmless (the names only
    # need to be unique and stable), and renaming it would change the category
    # names of already-generated grammars, so confirm before touching it.
    7: ('G', 'G1', 'G2', 'G3', 'G4', 'G5', 'G6', 'G8'),
    8: ('H', 'H1', 'H2', 'H3', 'H4', 'H5', 'H6', 'H7'),
})

# Generic type variables for the key and value types of a mapping.
# NOTE(review): presumably meant for annotating the generic mapping helpers in
# this module (e.g., _invert_mapping_of_sequences) -- confirm they are still
# referenced before removing.
_K = TypeVar('_K')
_V = TypeVar('_V')


def _invert_mapping_of_sequences(
    original_mapping):
  """Returns an inversion of the given mapping.

  Assumes that in the original mapping, while there can be multiple values for
  the same key, there should be only one key for any given value, so that the
  inverted mapping can be a simple mapping of key to value rather than of key
  to sequence of values.

  Args:
    original_mapping: Mapping to be inverted.
  """
  reverse_mapping = {}
  for key, values in original_mapping.items():
    for value in values:
      if value in reverse_mapping:
        raise ValueError(
            f'Same value unexpectedly duplicated across multiple keys: '
            f'value={value}, key1={reverse_mapping[value]}, key2={key}, '
            f'mapping={original_mapping}')
      reverse_mapping[value] = key
  return reverse_mapping


def _dict_of_tuples_from_dict_of_lists(
    x):
  return {k: tuple(v) for k, v in x.items()}


def _decode_possible_categories_by_level(
    dict_from_json
):
  """Restores possible_categories_by_level from its JSON-decoded form.

  Without this transformation, a value such as
    {0: ('U', 'W'), 1: ('D',), 2: ('V',), 3: ('S',)}
  would get restored from JSON as
    {'0': ['U', 'W'], '1': ['D'], '2': ['V'], '3': ['S']}
  i.e., with string keys and list values.

  Args:
    dict_from_json: Mapping of stringified level to sequence of categories.

  Returns:
    An immutabledict mapping integer level to tuple of categories.
  """
  categories_by_level = {}
  for level, categories in dict_from_json.items():
    categories_by_level[int(level)] = tuple(categories)
  return immutabledict.immutabledict(categories_by_level)


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class GrammarOptions:
  """Options controlling grammar generation.

  Every option comes with a default, so a GrammarOptions constructed without
  any arguments still leads to some sort of reasonable behavior.

  GrammarGenerator takes both these GrammarOptions and, optionally, a
  GrammarSchema template. When a template is given, the GrammarOptions only
  govern how the portions left unspecified in the template get populated.

  Attributes:
    input_vocabulary: Set of input tokens that are allowed to be used.
    output_vocabulary: Set of output tokens that are allowed to be used.
    possible_categories_by_level: Mapping of precedence level to the syntactic
      categories that are allowed to be output by that level.
    min_num_primitives: Minimum number of primitive mappings to generate.
      Ignored if num_primitives is specified.
    max_num_primitives: Maximum number of primitive mappings to generate.
      Ignored if num_primitives is specified.
    num_primitives: Exact number of primitive mappings to generate. If None,
      then GrammarGenerator will select a random number between the min and max.
    min_num_precedence_levels: Minimum number of precedence levels for which to
      generate non-primitive rules. Doesn't include precedence level 0. Ignored
      if num_precedence_levels is specified.
    max_num_precedence_levels: Maximum number of precedence levels for which to
      generate non-primitive rules. Doesn't include precedence level 0. Ignored
      if num_precedence_levels is specified.
    num_precedence_levels: Exact number of precedence levels for which to
      generate non-primitive rules. Doesn't include precedence level 0. If None,
      then GrammarGenerator will select a random number between the min and max.
    min_num_categories_per_level: Minimum number of syntactic categories to be
      output by any given precedence level, including level 0 (primitives).
      Ignored for the largest precedence level, which must always output exactly
      one category.
    max_num_categories_per_level: Maximum number of syntactic categories to be
      output by any given precedence level, including level 0 (primitives).
    min_num_functions_per_level: Minimum number of function rules to generate
      for any given precedence level, in the absence of any other constraints
      that may force the number of function rules to be even higher -- e.g., the
      number of categories selected to be output by the given level, and the
      need for there to be at least one non-PassThroughRule outputting any given
      category.
    max_num_functions_per_level: Maximum number of function rules to generate
      for any given precedence level. May be ignored in certain cases, if there
      are other constraints that would require the number of function rules to
      be higher. (See explanation of min_num_functions_per_level above.)
    min_num_args: Minimum number of arguments to generate for any given function
      rule. Must be at least 1.
    max_num_args: Maximum number of arguments to generate for any given function
      rule.
    min_num_postfix_args: Minimum number of arguments for any given function
      that should appear after the function phrase. Any remaining arguments
      would be considered prefix args (appearing before the function phrase).
    max_num_postfix_args: Maximum number of arguments for any given function
      that should appear after the function phrase.
    prob_pass_through_rule: Probability that any given precedence level will
      contain a PassThroughRule.
    prob_concat_rule: Probability that the grammar as a whole will contain a
      ConcatRule.
    max_repetitions_per_token_in_output_sequence: The maximum number of times
      that any given variable or raw output token is allowed to be repeated in
      the output sequence. For each output token, GrammarGenerator will select a
      number of repetitions between 1 and this maximum.
    max_output_sequence_size: The maximum number of total tokens (variables and
      output vocabulary) to include in the output sequence of any given
      FunctionRule or ConcatRule. A value < 0 means that there is no limit.
    max_cumulative_output_sequence_size: The target maximum output size of all
      the rules in a context. Note that this maximum is approximated rather than
      strictly enforced. A value < 0 means that there is no cumulative limit.
    min_unique_raw_tokens_in_output_sequence: The minimum number of unique
      tokens from the output vocabulary to include in the output sequence of any
      given FunctionRule or ConcatRule.
    max_unique_raw_tokens_in_output_sequence: The maximum number of unique
      tokens from the output vocabulary to include in the output sequence of any
      given FunctionRule or ConcatRule.
    max_raw_tokens_in_output_sequence: The maximum number of tokens from the
      output vocabulary to include in the output sequence of any given
      FunctionRule or ConcatRule. A value < 0 means that there is no limit.
    reuse_raw_tokens_in_output_sequence: Indicates whether raw tokens may be
      used in the output sequences of multiple rules.
    validate_concat_rule_level: A boolean indicating whether the concat rule
      arguments must be from exactly one level below. It is recommended to set
      this to True when generating random phrase structure grammars, to help
      avoid accidental generation of ambiguous grammars; when using a fixed
      phrase structure grammar template, this check is not necessary.
  """

  # Vocabularies from which input and output tokens may be drawn.
  input_vocabulary: FrozenSet[str] = dataclasses.field(
      default=frozenset(
          sorted(SCAN_INPUT_VOCABULARY | MINISCAN_INPUT_VOCABULARY)),
      # The decoder is given explicitly so that the FrozenSet is rebuilt from
      # sorted elements when restoring from JSON. Otherwise the ordering of the
      # elements could change, which would interfere with equality checks in
      # the unit tests.
      metadata=dataclasses_json.config(
          decoder=lambda tokens: frozenset(sorted(tokens))))
  output_vocabulary: FrozenSet[str] = dataclasses.field(
      default=frozenset(
          sorted(SCAN_OUTPUT_VOCABULARY | MINISCAN_OUTPUT_VOCABULARY)),
      metadata=dataclasses_json.config(
          decoder=lambda tokens: frozenset(sorted(tokens))))

  # For example, in scan_finite_nye.fcfg the syntactic categories 'U' and 'W'
  # are output by level 0, 'D' by level 1, and so on.
  possible_categories_by_level: immutabledict.immutabledict = dataclasses.field(
      default=immutabledict.immutabledict(POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL),
      metadata=dataclasses_json.config(
          decoder=_decode_possible_categories_by_level))

  # Primitives in SCAN are, e.g., 'walk', 'look', 'run', 'jump' (and, in
  # scan_finite_nye, also 'turn', 'left', and 'right').
  min_num_primitives: int = 7
  max_num_primitives: int = 7
  num_primitives: Optional[int] = None

  # For example, 'twice' and 'thrice' are function words sharing one precedence
  # level (level 3 in scan_finite_nye.fcfg); they bind more tightly than 'and'
  # and 'after' (level 4 in scan_finite_nye.fcfg).
  min_num_precedence_levels: int = 4
  max_num_precedence_levels: int = 4
  num_precedence_levels: Optional[int] = None

  # For example, scan_finite_nye.fcfg uses two categories (U and W) for
  # pre-terminals (level 0) and a single category for each of the remaining
  # precedence levels (D for level 1, V for level 2, S for level 3, and C for
  # level 4).
  min_num_categories_per_level: int = 1
  max_num_categories_per_level: int = 2

  # For example, in the original SCAN grammar 'twice' is a function taking a
  # single prefix arg (i.e., 1 arg, 0 of which are postfix), whereas 'after' is
  # a function taking one prefix arg and one postfix arg (i.e., 2 args, 1 of
  # which is postfix).
  min_num_functions_per_level: int = 0
  max_num_functions_per_level: int = 2
  min_num_args: int = 1
  max_num_args: int = 2
  min_num_postfix_args: int = 0
  max_num_postfix_args: int = 1

  # Probability that any given precedence level will contain a PassThroughRule.
  # For example, every level of scan_finite_nye.fcfg contains one.
  prob_pass_through_rule: float = 1

  # Probability that the grammar as a whole will contain a ConcatRule.
  # For example, scan_finite_nye.fcfg has its ConcatRule in level 1:
  #   D[sem=(?w+?u)] -> U[sem=?u] W[sem=?w]
  prob_concat_rule: float = 1

  # A function's output sequence consists of variables (corresponding to the
  # function args) and/or tokens from the output vocabulary. Since function
  # arguments may not be ignored, the output sequence is at least as long as
  # the number of function arguments; beyond that it could theoretically be of
  # arbitrary length, because the output sequence can include repetition.
  max_repetitions_per_token_in_output_sequence: int = 4
  max_output_sequence_size: int = -1
  max_cumulative_output_sequence_size: int = -1

  # Non-primitive rules in scan_finite_nye.fcfg never contain raw tokens.
  # In scan_finite.cfg, by contrast, a rule such as the following contains one
  # unique raw token ('LTURN'):
  #   V[sem=('LTURN'+'LTURN'+?u)] -> U[sem=?u] 'opposite' 'left'
  min_unique_raw_tokens_in_output_sequence: int = 0
  max_unique_raw_tokens_in_output_sequence: int = 0
  max_raw_tokens_in_output_sequence: int = -1
  reuse_raw_tokens_in_output_sequence: bool = False

  validate_concat_rule_level: bool = False

  def level_by_category(self):
    """Mapping of syntactic category to the precedence level that outputs it."""
    category_to_level = _invert_mapping_of_sequences(
        self.possible_categories_by_level)
    return immutabledict.immutabledict(category_to_level)


@enum.unique
class NegativeExampleStrategy(str, enum.Enum):
  """Supported strategies for creating negative examples.

  This is used in the distractor_generation module.

  HEURISTIC_EDIT:
    Modify the output sequence of the positive rule by adding, removing,
    repeating, or replacing a variable or a non-variable token.
  ALTERNATIVE_GRAMMAR:
    Replace a source production used to build the positive rule with one from
    a newly generated grammar.
  """
  # Each member's value is identical to its name; thanks to the str mix-in,
  # members are also str instances that compare equal to that string.
  HEURISTIC_EDIT = 'HEURISTIC_EDIT'
  ALTERNATIVE_GRAMMAR = 'ALTERNATIVE_GRAMMAR'


def _round_while_preserving_sum(x):
  """Returns the given number rounded in a way that preserves their sum.

  Examples:
    * [0.67, 0.67, 0.66] -> [1, 0, 1] (rather than [1, 1, 1])
    * [0.33, 0.33, 0.34] -> [0, 1, 0] (rather than [0, 0, 0])

  Args:
    x: The numbers to be rounded.
  """
  rounded = np.rint(np.cumsum(list(x))).astype(int)
  rounded[1:] -= rounded[:-1]  # Go from accumulated values back to diffs.
  return list(map(int, rounded))


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class ExampleLengthOptions:
  """Options controlling maximum tokenized length of top-level examples.

  If any "max" value is set to 0, then no limit will be applied.

  Attributes:
    max_input_length_standard: Maximum input length in STANDARD string
      representation using whitespace tokenization.
    max_output_length_standard: Maximum output length in STANDARD string
      representation using whitespace tokenization.
    max_input_length_compact: Maximum input length in COMPACT string
      representation using T5's default SentencePiece tokenization.
    max_output_length_compact: Maximum output length in COMPACT string
      representation using T5's default SentencePiece tokenization.
  """

  # All limits default to 0, i.e., by default no length limit is applied.
  max_input_length_standard: int = 0
  max_output_length_standard: int = 0
  max_input_length_compact: int = 0
  max_output_length_compact: int = 0


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class SamplingOptions:
  """Options controlling random sampling of context and examples.

  Each option has a default value, so that a SamplingOptions constructed with
  no arguments will lead to some sort of reasonable behavior.

  Attributes:
    num_contexts: Number of distinct contexts to generate.
    num_requests_per_context: Number of requests (i.e., top-level examples) to
      generate for any given context.
    num_rules: Mean of the distribution from which to sample the number of
      grammar rules that are used in any given context. A value < 0 indicates
      that all grammar rules are kept.
    num_rules_min: Minimum number of grammar rules that are kept for any
      context. This is ignored if num_rules < 0.
    num_rules_max: Maximum number of grammar rules that are kept for any
      context. This is ignored if num_rules < 0.
    num_rules_stddev: Standard deviation of the distribution from which to
      sample the number of grammar rules that are used in any given context.
      This is ignored if num_rules < 0.
    num_examples_per_hidden_rule: Number of examples to include in the context
      to illustrate each hidden grammar rule.
    illustrative_example_non_rule_fraction: Of the examples added to the context
      for illustrating a hidden rule, the fraction of these that are non-rule
      examples, as opposed to derived rule examples.
    omitted_fraction: Mean of the distribution from which to sample the fraction
      of rules to be omitted (not illustrated in any way) in the context.
      Examples that depend on omitted rules will have "?" as reply.
    omitted_fraction_stddev: Standard deviation of the distribution from which
      to sample the fraction of rules to be omitted in the context.
    unreliable_fraction: Mean of the distribution from which to sample the
      fraction of rules to be made unreliable in the context.  An unreliable
      rule is always replaced with a distractor rule whenever used for
      generating examples, including when the unreliable rule is the target
      rule.  Top-level examples that depend on unreliable rules will have
      UNKNOWN as reply.
    unreliable_fraction_stddev: Standard deviation of the distribution from
      which to sample the fraction of rules to be made unreliable in the
      context.
    explicit_fraction: Mean of the distribution from which to sample the
      fraction of rules to show explicitly in the context.
    explicit_fraction_stddev: Standard deviation of the distribution from which
      to sample the fraction of rules to show explicitly in the context.
    min_explicit_fraction: The minimum fraction of rules to show explicitly in
      the context.
    non_rule_fraction: Fraction of examples that have non-rule requests as
      opposed to rule requests.
    negative_example_fraction: Mean of the distribution from which to sample the
      fraction of rule requests to generate with FALSE reply.
    negative_example_fraction_stddev: Standard deviation of the distribution
      from which to sample the fraction of rule requests to generate with FALSE
      reply.
    defeasible_example_fraction: Mean of the distribution from which to sample
      the target fraction of top-level examples with qualifier D.
    defeasible_example_fraction_stddev: Standard deviation of the distribution
      from which to sample the target fraction of top-level examples with
      qualifier D.
    unknown_example_fraction: Mean of the distribution from which to sample the
      target fraction of top-level examples with UNKNOWN reply.
    unknown_example_fraction_stddev: Standard deviation of the distribution from
      which to sample the target fraction of top-level examples with UNKNOWN
      reply.
    additional_test_and_validation_requests_per_context: How many more examples
      per context should be generated for both test and validation splits.
    max_generation_time_per_request_type: Maximum number of seconds to spend on
      generating top level examples of each request type.
    max_derivation_level: Maximum derivation level when sampling derived
      productions.
    max_attempts_per_example: If no valid example of a certain type (e.g., an
      example illustrating a given rule or a negative distractor example based
      off of a given positive rule) could be generated after the given number of
      attempts, the algorithm will give up on generating an example of that type
      and move on with the generation of the rest of the dataset.
    max_attempts_per_context: The maximum number of attempts for generating a
      context with a given grammar.
    max_attempts_per_grammar: The maximum number of attempts for generating a
      grammar with a given grammar template. Dataset generation will be aborted
      if grammar generation fails this many times in a row for the given
      DatasetSpec.
    rule_format: The format used for representing grammar rules as rule requests
      in generated examples.
    alternative_grammar_fraction: The fraction of negative examples generated
      using the alternative grammar strategy.
    max_attempts_per_negative_example: The number of attempts for generating a
      negative example.
    derived_production_yield_probability: The probability of yielding a derived
      production instead of further expanding nonterminals, at any given step of
      the derived production sampling process.
    max_edits: The maximum number of heuristic edits for generating negative
      rule examples.
    inductive_bias: Inductive bias that encapsulates the criteria for inducing a
      rule from a given context.
    require_hidden_rules_to_satisfy_inductive_bias: If True, then will filter
      out any contexts in which any of the hidden rules failed to satisfy the
      inductive bias. If False, then will keep those contexts and simply treat
      the affected hidden rules as being "unknown" rather than "true".
    lengths: Options controlling maximum tokenized length of top-level examples.
  """

  num_contexts: int = 1
  num_requests_per_context: int = 1
  num_rules: int = 14
  num_rules_min: int = 14
  num_rules_max: int = 14
  num_rules_stddev: float = 0
  num_examples_per_hidden_rule: int = 4
  illustrative_example_non_rule_fraction: float = 0.5

  # Context options.
  # Depending on the implementation in dataset_generation, these need not be the
  # actual mean and standard deviations of the distributions used to sample the
  # fractions (e.g., they could be parameters passed to a truncated normal
  # distribution).
  omitted_fraction: float = 0.015
  omitted_fraction_stddev: float = 0.01
  unreliable_fraction: float = 0.04
  unreliable_fraction_stddev: float = 0.01
  explicit_fraction: float = 0.5
  explicit_fraction_stddev: float = 0.5
  min_explicit_fraction: float = 0.0

  # Top-level example options.
  non_rule_fraction: float = 0.5
  negative_example_fraction: float = 0.5
  negative_example_fraction_stddev: float = 0.1
  defeasible_example_fraction: float = 0.5
  defeasible_example_fraction_stddev: float = 0.1
  unknown_example_fraction: float = 0.2
  unknown_example_fraction_stddev: float = 0.05
  # NOTE(review): the meaning of the default of -1 is not documented here;
  # presumably it is a sentinel interpreted by the consumer -- confirm.
  additional_test_and_validation_requests_per_context: int = -1
  # Set the default value to 2 days (expressed in seconds).
  max_generation_time_per_request_type: int = 2 * 24 * 60 * 60

  max_derivation_level: int = 1000
  max_attempts_per_example: int = 400
  max_attempts_per_context: int = 10
  max_attempts_per_grammar: int = 5

  rule_format: enums.RuleFormat = enums.RuleFormat.INTERPRETATION_RULE

  alternative_grammar_fraction: float = 0.1
  max_attempts_per_negative_example: int = 20
  max_edits: int = 1

  derived_production_yield_probability: float = 0.75

  require_hidden_rules_to_satisfy_inductive_bias: bool = True
  # A custom JSON encoder/decoder pair is registered for the inductive bias;
  # a fresh default instance is created per SamplingOptions via default_factory.
  inductive_bias: induction.InductiveBias = (
      dataclasses.field(
          default_factory=induction.IllustrativeSubstitutionsInductiveBias,
          metadata=dataclasses_json.config(
              encoder=induction.InductiveBias.json_encode,
              decoder=induction.InductiveBias.json_decode)))

  lengths: ExampleLengthOptions = dataclasses.field(
      default_factory=ExampleLengthOptions)

  def calculate_schedule_of_examples_by_type(
      self,
      num_requests_per_context: Optional[int] = None
  ) -> Iterator[Mapping[cl.RequestType, int]]:
    """Yields the number of examples of each type to generate for each context.

    Each mapping that is yielded represents the breakdown of examples for one
    context. As much as possible, the same ratio of request types will be
    applied in each context; however, if the ideal number of examples of a given
    type per context would be fractional, the exact integer value will be
    adjusted appropriately so as to stay as close as possible to the desired
    ratio of request types in aggregate across contexts.

    This method yields an infinite stream of number-of-examples-by-type
    mappings (it never consults num_contexts); the caller is expected to take
    one mapping for each context that is to be generated.

    Note that a request type is not guaranteed to appear as a key in every
    yielded mapping: a type whose remaining unrounded share for a context is
    not positive is left out entirely, so a missing key should be treated as
    zero examples of that type.

    Args:
      num_requests_per_context: Optionally provide a num_requests_per_context
        target. This is useful for when generating additional examples for each
        context. If left unspecified the value of self.num_requests_per_context
        will be used.

    Yields:
      For each context in turn, an immutable mapping of request type to the
      integer number of new examples of that type to generate for that context.
    """
    if num_requests_per_context is None:
      num_requests_per_context = self.num_requests_per_context
    target_fraction_by_request_type = {
        cl.RequestType.NON_RULE: self.non_rule_fraction,
        cl.RequestType.RULE: 1 - self.non_rule_fraction
    }
    # Running total of the examples scheduled so far, across all contexts.
    num_examples_by_type = collections.Counter()
    context_index = 0
    while True:
      # Cumulative (possibly fractional) targets up to and including the
      # current context.
      target_num_examples = (context_index + 1) * num_requests_per_context
      target_num_examples_by_type = collections.Counter({
          key: target_num_examples * value
          for key, value in target_fraction_by_request_type.items()
      })

      # What is still missing to reach the cumulative target. Counter
      # subtraction keeps only strictly positive results, so a type that has
      # nothing left to allocate disappears from the mapping here.
      num_new_examples_by_type_unrounded = (
          target_num_examples_by_type - num_examples_by_type)
      # Round to integers in key order; keys() and values() of the same Counter
      # iterate in matching order, so zip pairs each key with its own count.
      num_new_examples_by_type = immutabledict.immutabledict(
          zip(
              num_new_examples_by_type_unrounded.keys(),
              _round_while_preserving_sum(
                  num_new_examples_by_type_unrounded.values())))
      # In-place Counter addition likewise discards non-positive totals.
      num_examples_by_type += collections.Counter(num_new_examples_by_type)
      yield num_new_examples_by_type
      context_index += 1


@enum.unique
class SplitBy(str, enum.Enum):
  """Represents different dataset splitting approaches.

  Each member's value is identical to its name; thanks to the str mix-in,
  members are also str instances that compare equal to that string.
  """
  EXAMPLE = 'EXAMPLE'
  CONTEXT = 'CONTEXT'
  HIGH_EXPLICIT_FRACTION_IN_TRAIN = 'HIGH_EXPLICIT_FRACTION_IN_TRAIN'
  LOW_EXPLICIT_FRACTION_IN_TRAIN = 'LOW_EXPLICIT_FRACTION_IN_TRAIN'
  HIGH_NUM_RULES_IN_TRAIN = 'HIGH_NUM_RULES_IN_TRAIN'
  LOW_NUM_RULES_IN_TRAIN = 'LOW_NUM_RULES_IN_TRAIN'
  CONTEXT_AND_OUTPUT_PATTERN = 'CONTEXT_AND_OUTPUT_PATTERN'
  CONTEXT_AND_OUTPUT_PATTERN_AND_RESHUFFLE = (
      'CONTEXT_AND_OUTPUT_PATTERN_AND_RESHUFFLE')
  SUBSAMPLE_AND_CONTEXT = 'SUBSAMPLE_AND_CONTEXT'
  COMPOUND_DIVERGENCE = 'COMPOUND_DIVERGENCE'


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class CompoundDivergenceOptions:
  """Options controlling dataset splitting by compound divergence.

  Attributes:
    use_rule_pattern: Whether to transform rule strings to rule patterns when
      forming compounds.
    max_compound_size: Maximum size of compounds (tuples of rules) to be
      considered.  Values from 2 up to max_compound_size (inclusive) will be
      used.
    top_level_example: If True, then will apply the compound divergence split at
      the level of individual top-level examples, which means that the content
      of the top-level example's request is taken into account, while no
      information from the top-level example's context is considered. If False,
      then will apply the compound divergence split at the level of contexts,
      which means that only the content of the context will be taken into
      account, not the content of the top-level example requests, and any given
      context will only appear in one of train or test, not both.
    filter_contexts: Whether to filter the contexts after splitting by top-level
      example compound divergence, so as to ensure that the same context does
      not appear in both train and test. This will lead to a reduction in the
      number of top-level examples in both the train and test set.
    composition_compound: If True, then uses compositions of atoms as compounds.
      If False (the default), then uses tuples of atoms as compounds. (Note:
      whether we look at the atoms and compositions of atoms from the top-level
      example's request or from the contents of the context examples depends on
      the value of `top_level_example`.)
    max_iterations: Maximum number of iterations (i.e., swaps) to run in the
      call to mcd_utils.swap_examples.
    max_divergence: If not None, will break if compound divergence exceeds this
      value in the call to mcd_utils.swap_examples.
    min_atom_count: Minimum number of times an atom in examples_2 should appear
      in examples_1 in the call to mcd_utils.swap_examples.
    use_insertion_deletion: Whether to use the insertion/deletion-based
      algorithm for splitting.  If False (the default), use the swap-based
      algorithm.
    output_fraction: The fraction of generated contexts to be included in the
      splits.
    initial_fraction: The fraction of target size to be randomly included in the
      subsets before the insertion/deletion iteration.  This helps prevent the
      compound counters from being empty in the case of top-level example
      compound divergence split.
    delete_period: For every delete_period insertions, one item will be removed
      from a target subset.
    target_atom_divergence: The target atom divergence.
    sample_size: The number of candidate items to consider in each iteration.
    filter_items_for_missing_atom: Whether to filter items to only those
      containing a missing atom before sampling for candidate items for the
      iteration.
    atom_similarity_exponent: The exponent used to calculate the atom similarity
      factor in an item's adequacy.
    atom_coverage_exponent: The exponent used to calculate the atom coverage
      factor in an item's adequacy.
    two_stage_mcd: If True, the compound divergence maximization algorithm is
      applied twice, first creating the test split from train+validation, then
      creating the train split and the validation split.  If False, the split of
      validation from train+validation will not use the compound divergence
      maximization algorithm.
  """
  # Options for the definition of atoms and compounds.
  use_rule_pattern: bool = False
  max_compound_size: int = 3
  top_level_example: bool = False
  filter_contexts: bool = True
  composition_compound: bool = False

  # Options for mcd_utils' swap-based algorithm (the algorithm used when
  # use_insertion_deletion is False).
  max_iterations: int = 10000
  max_divergence: Optional[float] = 0.9
  min_atom_count: int = 1

  # Options for insertion/deletion-based algorithm (enabled via
  # use_insertion_deletion).
  use_insertion_deletion: bool = False
  output_fraction: float = 1.0
  initial_fraction: float = 0.1
  delete_period: int = 3
  target_atom_divergence: float = 0.02
  sample_size: int = 200
  filter_items_for_missing_atom: bool = False
  atom_similarity_exponent: float = 4.0
  atom_coverage_exponent: float = 4.0
  two_stage_mcd: bool = True


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class SplitOptions:
  """Options controlling how a dataset is split into train/validation/test.

  Every option has a default value, so that a SplitOptions constructed with no
  arguments will lead to some sort of reasonable behavior.  The dataclass is
  frozen, so attributes cannot be reassigned after construction.

  Attributes:
    test_fraction: Fraction of the dataset to hold out for the test set.
    validation_fraction: Fraction of the dataset to hold out for the validation
      set.
    split_by: How the dataset will be split into train/validation/test.
    subsample_size: The total number of top-level examples to keep.  Only used
      if split_by is SUBSAMPLE_AND_CONTEXT.
    compound_divergence_options: Options for the compound divergence split
      algorithm.  Only used if split_by is COMPOUND_DIVERGENCE.
  """

  # Parameters that apply regardless of the splitting method.
  test_fraction: float = 0.1
  validation_fraction: float = 0.1
  split_by: SplitBy = SplitBy.CONTEXT

  # Additional parameters for SUBSAMPLE_AND_CONTEXT.
  subsample_size: int = 0

  # Additional parameters for COMPOUND_DIVERGENCE.  Built via default_factory
  # so that each SplitOptions constructs its own default options object.
  compound_divergence_options: CompoundDivergenceOptions = dataclasses.field(
      default_factory=CompoundDivergenceOptions)


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class GenerationOptions:
  """Options that together control how one benchmark dataset is generated.

  Attributes:
    random_seed: Seed for the data generation process.  When this
      GenerationOptions is handed to a process that generates several dataset
      replicas, the seed is used to derive a series of random integers that
      serve as the actual seeds of the individual replicas.  The value 0 means
      that a seed is derived from the current timestamp instead, in which case
      the seed is not reproducible.
    random_seed_same_as: Id of the representative dataset spec whose random
      seed this one intentionally shares.  When several dataset specs are meant
      to use the same seed, one of them is chosen as the representative and
      every other one names it here.  If left unset, a validation in the unit
      tests enforces that the dataset spec has a unique random seed.
    grammar: Options controlling grammar generation.
    sampling: Options controlling random sampling of context and examples.
    splitting: Options controlling splitting of the dataset.
  """

  random_seed: int = 0
  random_seed_same_as: Optional[str] = None
  grammar: GrammarOptions = dataclasses.field(
      default_factory=GrammarOptions)
  sampling: SamplingOptions = dataclasses.field(
      default_factory=SamplingOptions)
  splitting: SplitOptions = dataclasses.field(
      default_factory=SplitOptions)

  def use_timestamp_for_random_seed(self):
    """Returns True if the seed should be derived from the time stamp."""
    # A random_seed of 0 is the marker for "no explicit seed was given".
    seed_is_unset = self.random_seed == 0
    return seed_is_unset

  def random_seed_or_timestamp(self):
    """Returns the explicit random_seed, or else a time-stamp-based seed."""
    if not self.use_timestamp_for_random_seed():
      return self.random_seed
    # Fold the nanosecond clock reading into the range [0, 2**32).
    return time.time_ns() % 2**32


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class DatasetSpec:
  """Named specification of a single benchmark dataset to be generated.

  Multiple datasets are allowed to be created from the same DatasetSpec, either
  simultaneously (to account for different possible random seeds) or at later
  points in time (to reflect changes in the benchmark generation code).

  The id and description must always be provided; the remaining attributes have
  defaults.

  Attributes:
    id: Unique identifier of the dataset specification. Should be in a format
      compatible with use as a directory name in the benchmark output location.
    description: Free-form text describing the dataset.
    generation_options: Options controlling generation of the dataset. Defaults
      to a GenerationOptions constructed with no arguments.
    template_grammar_id: A choice of standard grammar from which a
      GrammarSchema template will be derived that preserves the phrase
      structure, while allowing output sequences to be randomly generated. If
      specified, this will be used as the template for all grammar generation
      calls within the benchmark generation run. If not specified (None, the
      default), then no GrammarSchema template will be used.
  """

  # Required fields (no defaults).
  id: str
  description: str
  # Optional fields.
  generation_options: GenerationOptions = dataclasses.field(
      default_factory=GenerationOptions)
  template_grammar_id: Optional[grammar_loader.StandardGrammarId] = None


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class DatasetSuiteSpec:
  """Named specification of a suite of benchmark datasets to be generated.

  The id and description must always be provided.

  Attributes:
    id: Unique identifier of the dataset suite specification. Should be in a
      format compatible with use as a directory name.
    description: Free-form text describing the dataset suite.
    dataset_specs: Ids of the DatasetSpecs to use. Defaults to an empty list.
  """

  # Required fields (no defaults).
  id: str
  description: str
  # The suite refers to its DatasetSpecs by id string rather than embedding
  # them.
  dataset_specs: List[str] = dataclasses.field(default_factory=list)
