# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for generating conceptual learning datasets.

A conceptual learning dataset consists of a set of examples, each of which
contains a request and response, as well as a context and a qualifier indicating
whether the response is true monotonically or only defeasibly. The context is
itself represented as a set of examples and contains background knowledge that
directly or indirectly describes the rules of an underlying task of translating
requests into responses. The idea is that rather than having the entire dataset
deal with a single task, we can have each top-level example deal with a
different underlying task, which is described in the context of that example.

Specifically, this library focuses on generation of datasets satisfying the
following criteria:
* Each underlying task is describable as an nltk.grammar.FeatureGrammar -- or
  more specifically, as a GrammarSchema, which corresponds to a FeatureGrammar
  that is similar in structure to the original SCAN rule set of
  [Lake and Baroni 2018] (https://arxiv.org/pdf/1711.00350.pdf).
* The context of each example is only one level deep. That is, while each
  top-level example contains a non-empty context, the examples inside of the
  context all have empty contexts.

Dataset generation involves the following steps:
1. Generating large numbers of random FeatureGrammars, each of which represents
   an underlying task. This step is delegated to the GrammarGenerator class.
2. Generating random examples of the underlying task. This step is delegated to
   the ExampleGenerator class.
3. Generating top-level examples, each of which contains a context, and for
   which the request and response are in one of two forms:
   (a) the request and response correspond to the input and output of the
       underlying task
   (b) the request is a rule and the response indicates whether the given rule
       is true in the underlying task
   This step is handled directly in this library.
"""

import collections
import copy
import dataclasses
import functools
import itertools
import logging
import time

from typing import (Callable, Dict, Iterable, Iterator, List, Mapping, Optional,
                    Sequence, Tuple, Union)

import attr
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import grammar_schema
from conceptual_learning.cscan import induction
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import rule_conversion
from conceptual_learning.cscan import sampling
from conceptual_learning.cscan import stats_utils
from conceptual_learning.cscan import tokenization


def _add_unique_examples(
    dataset, examples,
    counters):
  """Adds the given examples to the dataset, while skipping duplicates.

  Args:
    dataset: Dataset to be added to.
    examples: Examples to add.
    counters: Counters tracking statistics about the generation process. These
      will be updated here as needed to reflect filtering of duplicates.

  Yields:
    The examples that were successfully added.
  """
  for example in examples:
    added_example = dataset.add_example(example)
    if (added_example is not example) or (added_example is None):
      logging.debug('Skipping duplicate request: %s', example.request)
      counters.example_attempts.duplicate += 1
      counters.example_attempts.valid -= 1
    else:
      yield example


def _filter_inconsistent_examples(
    examples, counters,
    inference_engine):
  """Filters inconsistent examples from the stream of examples.

  If the example's production is consistent with what is currently in the
  inference engine, it will be added into it.

  Args:
    examples: Examples to filter.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.
    inference_engine: The inference engine for maintaining the context's
      consistency.

  Yields:
    The examples that could be added without causing inconsistency.
  """
  for example in examples:
    try:
      inference_engine.add_production(
          example.metadata.production, is_monotonic=True)
    except inference.InconsistencyError:
      logging.debug('Skipping inconsistent example with production: %s',
                    example.metadata.production)
      counters.example_attempts.context_example_inconsistent += 1
      continue

    yield example


def _filter_examples_illustrating_distractor_rules(
    examples, max_distractor_rule_illustration,
    context,
    counters):
  """Filters examples that illustrate distractor rules.

  This function filters out examples that if added to the context would allow
  at least one distractor rule to be illustrated too many times.  Currently the
  threshold is set to 2.

  Args:
    examples: Examples to filter.
    max_distractor_rule_illustration: Maximum number of different examples
      allowed to illustrate the same distractor rule.
    context: The mutable dataset the examples are to be added to.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.

  Yields:
    The examples that could be added without illustrating any distractor rule
    too many times.
  """
  for example in examples:
    should_yield = True
    for distractor_rule in example.metadata.iter_distractor_rules():
      examples_using_distractor_rule = (
          context.metadata.examples_by_rule.get(distractor_rule, []))
      if (len(examples_using_distractor_rule) >=
          max_distractor_rule_illustration):
        should_yield = False
        (counters.example_attempts.illustrating_distractor_rule_too_many_times
        ) += 1
        break

    if should_yield:
      yield example


def _generate_and_add_illustrative_examples_to_context(
    simple_example_generator,
    target_num,
    target_rule,
    mutable_context,
    options,
    counters,
    inference_engine = None
):
  """Generates illustrative examples for target_rule and adds them to context.

  The target number is split between the generator's non-rule examples
  (int(target_num * options.illustrative_example_non_rule_fraction), i.e.
  truncated) and its derived-rule examples (the remainder). For each kind, the
  candidate examples pass through the following filters, in order:
    1. Examples that would illustrate a distractor rule too many times are
       dropped.
    2. If inference_engine is given, examples whose production is inconsistent
       with it are dropped; consistent productions are added to the engine.
    3. Examples that mutable_context rejects as duplicates are dropped; the
       rest are added to mutable_context.
  Fewer than target_num examples may be added when not enough candidates
  survive the filters within the attempt budget
  (options.max_attempts_per_example per requested example).

  Args:
    simple_example_generator: Example generator based on the given grammar. Used
      for generating examples of the underlying task.
    target_num: The target number of illustrative examples for the target rule.
    target_rule: The rule to illustrate.
    mutable_context: The context to add the examples to.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.
    inference_engine: The inference engine for maintaining the context's
      consistency. If None, the consistency filter is skipped.

  Returns:
    The list of examples added to the context, non-rule examples first.
  """
  # NOTE(review): the reason for allowing two fewer distractor illustrations
  # than options.num_examples_per_hidden_rule is not documented here.
  max_distractor_rule_illustration = options.num_examples_per_hidden_rule - 2

  def _generate_and_add(example_generation_func, num_to_add):
    """Runs one generation function through the filters; returns added examples.

    example_generation_func must have the same signature as
    ExampleGenerator.generate_n_non_rule_examples and
    ExampleGenerator.generate_n_derived_rule_examples.
    """
    candidate_examples = example_generation_func(
        options.max_attempts_per_example * num_to_add,
        target_rule,
        rules_to_avoid_as_dependency=mutable_context.metadata.omitted_rules,
        unreliable_rules=mutable_context.metadata.unreliable_rules)
    candidate_examples = _filter_examples_illustrating_distractor_rules(
        candidate_examples, max_distractor_rule_illustration, mutable_context,
        counters)
    if inference_engine is not None:
      # The productions of the examples passing this filter are added to
      # inference_engine. In order to use inference_engine to sample top-level
      # examples later on, its source productions must match the productions
      # in the context exactly, so this filter must run right before the
      # examples are added to the context.
      candidate_examples = _filter_inconsistent_examples(
          candidate_examples, counters, inference_engine)
    # The filters are lazy generators, so islice stops pulling (and thus stops
    # adding to the context/engine) as soon as num_to_add examples are added.
    return tuple(
        itertools.islice(
            _add_unique_examples(mutable_context, candidate_examples, counters),
            num_to_add))

  target_num_non_rule = int(target_num *
                            options.illustrative_example_non_rule_fraction)
  target_num_derived_rule = target_num - target_num_non_rule
  # Ordered (generation function, number of examples to add) pairs.
  generation_plan = (
      (simple_example_generator.generate_n_non_rule_examples,
       target_num_non_rule),
      (simple_example_generator.generate_n_derived_rule_examples,
       target_num_derived_rule),
  )

  illustrative_examples = []
  for example_generation_func, num_to_add in generation_plan:
    illustrative_examples.extend(
        _generate_and_add(example_generation_func, num_to_add))
  return illustrative_examples


def _sample_omitted_explicit_unreliable_fractions(
    num_rules, options,
    rng):
  """Samples the fractions of rules to treat as omitted, explicit and unreliable.

  Each fraction is drawn with stats_utils.sample_clipped_truncated_normal using
  the corresponding mean/stddev pair from `options`, truncated to the range
  [-width / 2, 1 + width / 2] with width = 1 / num_rules. Widening the range
  beyond [0, 1] by half a rule's worth on each side is intended to give the two
  buckets at the ends the same mass as the others.

  The explicit fraction is additionally floored at
  options.min_explicit_fraction. If the three fractions sum to more than 1,
  they are rescaled proportionally so that they sum to exactly 1 (which may
  bring the explicit fraction below options.min_explicit_fraction). Otherwise
  they are returned unchanged, leaving 1 - sum for the remaining rules.

  Args:
    num_rules: The number of rules to which the fractions will be applied when
      sampling omitted/explicit/unreliable rules. Values below 1 are treated as
      1 to avoid dividing by zero.
    options: Bundle of options controlling the generation algorithm.
    rng: Random number generator.

  Returns:
    A tuple of (omitted_fraction, explicit_fraction, unreliable_fraction).
  """
  width = 1.0 / max(num_rules, 1)
  left = 0.0 - width / 2
  right = 1.0 + width / 2

  # The sampling order below (omitted, explicit, unreliable) determines how
  # the rng stream is consumed; keep it stable for reproducibility.
  omitted_fraction = stats_utils.sample_clipped_truncated_normal(
      left, right, options.omitted_fraction, options.omitted_fraction_stddev,
      rng)
  explicit_fraction = max(
      options.min_explicit_fraction,
      stats_utils.sample_clipped_truncated_normal(
          left, right, options.explicit_fraction,
          options.explicit_fraction_stddev, rng))
  unreliable_fraction = stats_utils.sample_clipped_truncated_normal(
      left, right, options.unreliable_fraction,
      options.unreliable_fraction_stddev, rng)

  total = omitted_fraction + explicit_fraction + unreliable_fraction
  if total > 1:
    omitted_fraction /= total
    explicit_fraction /= total
    unreliable_fraction /= total
  return omitted_fraction, explicit_fraction, unreliable_fraction


def _adjust_context_and_inference_engine_based_on_inductive_bias(
    original_context,
    original_engine,
    inductive_bias
):
  """Returns a context and an engine consistent with the inductive bias.

  Each hidden rule of original_context is checked once with
  inductive_bias.can_induce_rule. It is recorded in the new context's metadata
  (rule_reply_by_hidden_rule) as cl.RuleReply.TRUE if it can be induced and as
  cl.RuleReply.UNKNOWN otherwise.

  If the new context reports no hidden_unknown_rules, original_engine is
  returned as-is. Otherwise a new engine is built from
  original_engine.copy_monotonic_engine(). Of the original engine's source
  productions, only those whose rule is an inducible hidden rule are re-added
  to it, as defeasible productions.

  Does not modify the original context or inference_engine.

  Args:
    original_context: The original context, in which all hidden rules were
      initially treated as if they were defeasibly true, regardless of whether
      they actually satisfied the inductive bias.
    original_engine: The original inference engine, in which all hidden rules
      were initially treated as if they were defeasibly true, regardless of
      whether they actually satisfied the inductive bias.
    inductive_bias: The inductive bias to use in determining which hidden rules
      are actually true and which ones are unknown.

  Returns:
    A (new_context, new_engine) tuple. new_engine may be original_engine
    itself when every hidden rule satisfies the inductive bias.
  """
  # Evaluate the inductive bias exactly once per hidden rule; the result is
  # reused below when rebuilding the engine.
  rule_reply_by_hidden_rule = {}
  for rule in original_context.metadata.hidden_rules:
    can_induce = inductive_bias.can_induce_rule(rule, original_context)
    rule_reply_by_hidden_rule[rule] = (
        cl.RuleReply.TRUE if can_induce else cl.RuleReply.UNKNOWN)
  new_context = attr.evolve(
      original_context,
      metadata=attr.evolve(
          original_context.metadata,
          rule_reply_by_hidden_rule=rule_reply_by_hidden_rule))

  if not new_context.metadata.hidden_unknown_rules:
    # All of the hidden rules satisfied the inductive bias, so the original
    # engine can be used as-is.
    return new_context, original_engine

  # Keep the same monotonic contents as before, but in the defeasible portion
  # only include the hidden rules that actually satisfy the inductive bias.
  new_engine = original_engine.copy_monotonic_engine()
  rule_format = original_context.metadata.rule_format
  for production in original_engine.source_productions:
    rule = rule_conversion.rule_from_production(production, rule_format)
    # Non-hidden rules are absent from the dict, so .get() returns None.
    if rule_reply_by_hidden_rule.get(rule) == cl.RuleReply.TRUE:
      new_engine.add_production(production, is_monotonic=False)
  return new_context, new_engine


def _generate_context_and_inference_engine(
    simple_example_generator,
    options, counters,
    rng
):
  """Returns a generated context and an inference engine for the context.

  The context consists of a set of Examples that illustrate each of the rules in
  the given simple example generator either directly or indirectly. This
  function first randomly assigns rules as omitted, explicit, unreliable, or
  hidden.  Omitted rules are not illustrated in the context in any way.
  Explicit rules are illustrated by at least one example that directly asserts
  the rule to be true.  For unreliable rules and hidden rules, some number of
  examples of the underlying task are added that indirectly illustrate the
  behavior of the given rule.  Whenever an unreliable rule is used in the
  generation of any example, it is randomly replaced with a distractor rule.

  The inference engine receives the following as monotonic productions: the
  pass-through productions (those without rules), the productions of the
  explicit rules, and the productions of every illustrative example added to
  the context. The hidden rules' productions are added as defeasible
  productions. Illustrative examples whose productions would make the engine
  inconsistent are skipped. If the resulting context fails the length or
  quality check, context generation is retried with the same rule assignment.
  Before returning, the context and engine are adjusted according to
  options.inductive_bias via
  _adjust_context_and_inference_engine_based_on_inductive_bias.

  Args:
    simple_example_generator: Example generator based on the given grammar. Used
      for generating examples of the underlying task.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.
    rng: Random number generator.

  Returns:
    A (context, inference_engine) tuple, where context is built from a
    cl.FrozenExampleSet.

  Raises:
    MaxAttemptsReachedError: If a context passing both _context_length_check
      and _context_quality_check is not generated after
      options.max_attempts_per_context attempts.
  """
  rules = simple_example_generator.get_rules()
  # NOTE(review): shuffles in place whatever list get_rules() returns; confirm
  # it is a fresh copy rather than the generator's internal state.
  rng.shuffle(rules)

  omitted_fraction, explicit_fraction, unreliable_fraction = (
      _sample_omitted_explicit_unreliable_fractions(len(rules), options, rng))
  # NOTE(review): the three counts are rounded independently, so their sum can
  # exceed len(rules). The slicing below then silently shortens the later
  # categories (hidden first, then unreliable, ...).
  num_omitted_rules = round(len(rules) * omitted_fraction)
  num_explicit_rules = round(len(rules) * explicit_fraction)
  num_unreliable_rules = round(len(rules) * unreliable_fraction)

  # The shuffled rules are split into consecutive slices:
  # [omitted | explicit | unreliable | hidden (everything left)].
  start_omitted = 0
  end_omitted = start_explicit = start_omitted + num_omitted_rules
  end_explicit = start_unreliable = start_explicit + num_explicit_rules
  end_unreliable = start_hidden = start_unreliable + num_unreliable_rules

  omitted_rules = rules[start_omitted:end_omitted]
  explicit_rules = rules[start_explicit:end_explicit]
  unreliable_rules = rules[start_unreliable:end_unreliable]
  hidden_rules = rules[start_hidden:]

  # In case we need multiple attempts to generate a context of good quality, we
  # keep a copy of the inference engine that has all the explicit rules, hidden
  # rules, and pass through rules already added so in subsequent attempts it can
  # be reused.
  base_inference_engine = None

  # We start the loop for context generation and context quality check only
  # after the assignment of rules as omitted/explicit/unreliable/hidden has
  # been determined, so that dropping low quality contexts will have the least
  # effect on the distribution.
  for attempt_index in range(1, options.max_attempts_per_context + 1):
    # The base engine is built lazily, once, on the first attempt.
    if base_inference_engine is None:
      logging.info(
          'Building base inference engine for %d/%d/%d/%d '
          'omitted/explicit/unreliable/hidden rules', len(omitted_rules),
          len(explicit_rules), len(unreliable_rules), len(hidden_rules))
      logging.info('omitted rules: %s', omitted_rules)
      logging.info('explicit rules: %s', explicit_rules)
      logging.info('unreliable rules: %s', unreliable_rules)
      logging.info('hidden rules: %s', hidden_rules)
      base_inference_engine = inference.InferenceEngine(
          provenance_by_production=(
              simple_example_generator.provenance_by_production))
      # We include all the pass-through rule productions, although they do not
      # appear as examples' productions, as they are needed for calculating
      # derived productions.
      for production in (
          simple_example_generator.get_productions_without_rules()):
        base_inference_engine.add_production(production, is_monotonic=True)

      for rule in explicit_rules:
        logging.debug(
            'Adding explicit rule to inference engine as monotonic production: '
            '%s', rule)
        for production in (
            simple_example_generator.get_productions_from_rule(rule)):
          base_inference_engine.add_production(production, is_monotonic=True)

      # Hidden rules are added as defeasible productions.
      for rule in hidden_rules:
        logging.debug(
            'Adding hidden rule to inference engine as defeasible production: '
            '%s', rule)
        for production in (
            simple_example_generator.get_productions_from_rule(rule)):
          base_inference_engine.add_production(production)

    logging.info('Generating context (attempt %d of %d)', attempt_index,
                 options.max_attempts_per_context)
    # Each attempt works on a backup of the base engine, presumably so that
    # productions added by a failed attempt do not leak into the next one.
    inference_engine = base_inference_engine.backup_states()

    mutable_context = cl.ExampleSet()
    mutable_context.metadata.rule_format = options.rule_format
    mutable_context.metadata.grammar = simple_example_generator.get_grammar()
    for rule in omitted_rules:
      logging.debug('Adding omitted rule to context: %s', rule)
      mutable_context.add_omitted_rule(rule)

    for rule in explicit_rules:
      logging.debug('Adding explicit rule to context: %s', rule)
      example = simple_example_generator.get_example_for_explicit_rule(rule)
      mutable_context.add_explicit_rule(rule, example)

    # We need to first mark the rules as unreliable, since it changes the
    # example generation process for all unreliable and hidden rules.
    for rule in unreliable_rules:
      logging.debug('Marking rule as unreliable: %s', rule)
      mutable_context.mark_rule_as_unreliable(rule)

    # We shuffle the unreliable and hidden rules together before adding their
    # illustrative examples to the context to avoid leaking any information
    # about which rules are unreliable based on the position of their examples.
    unreliable_rules_set = set(unreliable_rules)
    illustrated_rules = list(
        itertools.chain.from_iterable([unreliable_rules, hidden_rules]))
    rng.shuffle(illustrated_rules)
    for rule in illustrated_rules:
      is_unreliable = rule in unreliable_rules_set
      if is_unreliable:
        logging.debug('Illustrating unreliable rule via examples: %s', rule)
      else:
        logging.debug(
            'Hiding rule from context and illustrating via examples: %s', rule)

      # The illustrative examples are checked for consistency and added to the
      # inference engine in the _filter_inconsistent_examples function, called
      # by _generate_and_add_illustrative_examples_to_context.
      illustrative_examples = (
          _generate_and_add_illustrative_examples_to_context(
              simple_example_generator,
              target_num=options.num_examples_per_hidden_rule,
              target_rule=rule,
              mutable_context=mutable_context,
              options=options,
              counters=counters,
              inference_engine=inference_engine))

      if is_unreliable:
        mutable_context.add_unreliable_rule(rule, illustrative_examples)
      else:
        mutable_context.add_hidden_rule(rule, illustrative_examples)

      # A shortfall is recorded, but on its own it does not abort the attempt.
      if len(illustrative_examples) < options.num_examples_per_hidden_rule:
        logging.warning(
            'Unable to generate the requested number of illustrative examples '
            'for the given rule: %d < %d, rule = %s',
            len(illustrative_examples), options.num_examples_per_hidden_rule,
            rule)
        counters.errors.failed_to_illustrate_target_rule += 1

    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    # Both checks must pass; otherwise the next attempt starts from scratch
    # with the same rule assignment.
    if (_context_length_check(context, options, counters) and
        _context_quality_check(context, options, counters)):
      return _adjust_context_and_inference_engine_based_on_inductive_bias(
          context, inference_engine, options.inductive_bias)

  raise sampling.MaxAttemptsReachedError('Failed to generate context of good '
                                         'illustration quality.')


def _filter_requests_already_in_context(
    examples, context,
    counters):
  for example in examples:
    if context.request_already_in_example_set(example.request):
      logging.debug('Discarded request already in context: %s', example)
      counters.example_attempts.already_in_context += 1
      counters.example_attempts.valid -= 1
    else:
      yield example


def _get_target_fraction_by_example_type(
    request_type, options,
    context,
    rng):
  """Samples the desired share of each example type for a request type.

  The unknown and defeasible fractions (and, for rule requests, the negative
  fraction) are sampled in that order from clipped truncated normal
  distributions on [0, 1] configured by `options`. Example types that cannot
  occur for this context are zeroed out: the defeasible fraction when there are
  no hidden rules, and the unknown fraction when there are neither omitted nor
  unreliable rules.

  Args:
    request_type: The cl.RequestType of the examples to be generated.
    options: Bundle of options controlling the generation algorithm.
    context: Context whose metadata determines which example types are possible.
    rng: Random number generator.

  Returns:
    A dict mapping each cl.ExampleType relevant to request_type to its target
    fraction.
  """

  def _sample(mean, std):
    return stats_utils.sample_clipped_truncated_normal(
        left=0, right=1, mean=mean, std=std, rng=rng)

  unknown_fraction = _sample(options.unknown_example_fraction,
                             options.unknown_example_fraction_stddev)
  defeasible_fraction = _sample(options.defeasible_example_fraction,
                                options.defeasible_example_fraction_stddev)

  # Zero out the fractions of example types this context cannot produce, so
  # that no time is wasted trying to generate them.
  metadata = context.metadata
  if not metadata.hidden_rules:
    defeasible_fraction = 0.0
  if not (metadata.omitted_rules or metadata.unreliable_rules):
    unknown_fraction = 0.0

  known_fraction = 1.0 - unknown_fraction
  monotonic_fraction = 1.0 - defeasible_fraction

  if request_type != cl.RequestType.NON_RULE:
    negative_fraction = _sample(options.negative_example_fraction,
                                options.negative_example_fraction_stddev)
    positive_fraction = 1.0 - negative_fraction
    known_positive = known_fraction * positive_fraction
    known_negative = known_fraction * negative_fraction
    return {
        cl.ExampleType.RULE_KNOWN_TRUE_D: known_positive * defeasible_fraction,
        cl.ExampleType.RULE_KNOWN_TRUE_M: known_positive * monotonic_fraction,
        cl.ExampleType.RULE_KNOWN_FALSE_D: known_negative * defeasible_fraction,
        cl.ExampleType.RULE_KNOWN_FALSE_M: known_negative * monotonic_fraction,
        cl.ExampleType.RULE_UNKNOWN_D: unknown_fraction,
    }

  return {
      cl.ExampleType.NONRULE_KNOWN_D: known_fraction * defeasible_fraction,
      cl.ExampleType.NONRULE_KNOWN_M: known_fraction * monotonic_fraction,
      cl.ExampleType.NONRULE_UNKNOWN_D: unknown_fraction,
  }


def _merge_example_iterators_using_target_fraction_by_example_type(
    request_type, examples_by_type,
    target_fraction_by_example_type,
    counters):
  """Yields examples from the stream according to target fractions."""

  def sorted_example_types():
    """Returns example types sorted by decreasing target fraction deficit."""
    current_fraction_by_example_type_by_request_type = (
        counters.examples.get_fraction_by_example_type_by_request_type())
    current_fraction_by_example_type = (
        current_fraction_by_example_type_by_request_type[request_type])
    deficit_by_example_type = {}
    for example_type, target_fraction in (
        target_fraction_by_example_type.items()):
      deficit_by_example_type[example_type] = (
          target_fraction - current_fraction_by_example_type[example_type])

    return sorted(
        deficit_by_example_type,
        key=lambda example_type: deficit_by_example_type[example_type],
        reverse=True)

  # If the target fraction of an example type is 0.0, we do not use the example
  # stream of that type at all.
  available_example_types = set(
      example_type for example_type in examples_by_type
      if target_fraction_by_example_type[example_type] > 0.0)
  while available_example_types:
    selected_example_type = next(
        filter(lambda x: x in available_example_types, sorted_example_types()))
    selected_stream = examples_by_type[selected_example_type]
    try:
      # An alternative of having multiple examples treams one for each example
      # type if to get examples from a single stream until the currently
      # requested type appears, but keeping all the unused examples in separate
      # buffers and yield from them when needed.
      yield next(selected_stream)
    except StopIteration:
      # If a stream of examples has been exhausted, we move on to the next
      # available stream with the largest deficit instead of giving up.
      available_example_types.remove(selected_example_type)


def _populate_input_output_lengths(
    examples, context,
    enable_remote_dependencies):
  """Yields copies of the examples with input/output lengths in their metadata.

  The lengths are measured on a copy of each example that has `context`
  attached. The STANDARD-format lengths are always filled in. The
  COMPACT-format lengths are filled in only when enable_remote_dependencies is
  true.

  Args:
    examples: Top-level examples for which to calculate input/output lengths.
    context: The context that applies to these top-level examples. (Typically,
      this will not yet be populated inside of the example objects themselves,
      due to use of the GroupedExampleSet representation.)
    enable_remote_dependencies: Whether to enable dependencies on remote
      services such as the T5X tokenizer. (See explanation in generate_dataset.)

  Yields:
    For each input example, an evolved copy whose metadata carries the computed
    lengths. The yielded copy keeps the example's original context.
  """
  for example in examples:
    contextualized = attr.evolve(example, context=context)
    length_updates = {
        'input_length_standard':
            tokenization.get_input_length(
                contextualized, tokenization.ExampleStringFormat.STANDARD),
        'output_length_standard':
            tokenization.get_output_length(
                contextualized, tokenization.ExampleStringFormat.STANDARD),
    }
    if enable_remote_dependencies:
      length_updates['input_length_compact'] = tokenization.get_input_length(
          contextualized, tokenization.ExampleStringFormat.COMPACT)
      length_updates['output_length_compact'] = tokenization.get_output_length(
          contextualized, tokenization.ExampleStringFormat.COMPACT)
    new_metadata = attr.evolve(example.metadata, **length_updates)
    yield attr.evolve(example, metadata=new_metadata)


def _filter_examples_exceeding_max_input_output_lengths(
    examples, options,
    counters):
  """Yields just the examples with acceptable input and output lengths."""
  for example in examples:
    if (options.lengths.max_input_length_compact and
        (example.metadata.input_length_compact >
         options.lengths.max_input_length_compact)):
      logging.debug('Discarded example exceeding max input length: %s', example)
      counters.example_attempts.exceeded_max_input_length += 1
      counters.example_attempts.valid -= 1
      continue

    if (options.lengths.max_input_length_standard and
        (example.metadata.input_length_standard >
         options.lengths.max_input_length_standard)):
      logging.debug('Discarded example exceeding max input length: %s', example)
      counters.example_attempts.exceeded_max_input_length += 1
      counters.example_attempts.valid -= 1
      continue

    if (options.lengths.max_output_length_compact and
        (example.metadata.output_length_compact >
         options.lengths.max_output_length_compact)):
      logging.debug('Discarded example exceeding max output length: %s',
                    example)
      counters.example_attempts.exceeded_max_output_length += 1
      counters.example_attempts.valid -= 1
      continue

    if (options.lengths.max_output_length_standard and
        (example.metadata.output_length_standard >
         options.lengths.max_output_length_standard)):
      logging.debug('Discarded example exceeding max output length: %s',
                    example)
      counters.example_attempts.exceeded_max_output_length += 1
      counters.example_attempts.valid -= 1
      continue

    yield example


def _generate_and_add_top_level_examples_of_request_type(
    dataset, request_type, num_examples,
    simple_example_generator,
    context, inference_engine,
    options, counters,
    rng, enable_remote_dependencies):
  """Generates examples of the given type and adds them to the dataset.

  For each request type the function tries to maintain the requested target
  distribution over example types (as returned by
  _get_target_fraction_by_example_type):
  - For RULE examples: combinations of the rule reply (TRUE / FALSE / UNKNOWN)
    and the qualifier.
  - For NON_RULE examples: combinations of whether the reply is unknown and
    the qualifier.

  The implementation first creates one stream of candidate examples per
  example type (each generator is asked for at most
  `num_examples * options.max_attempts_per_example` examples). It then merges
  the streams, choosing which one to draw from next based on the examples that
  have been added so far. Next it filters out candidates whose request already
  appears in the context or whose input/output is too long. Finally it adds
  unique examples to the dataset until `num_examples` have been added, the
  candidates run out, or the time limit is exceeded.

  Args:
    dataset: Dataset to be added to.
    request_type: The request type to generate.
    num_examples: Number of examples to generate.
    simple_example_generator: Example generator based on the given grammar. Used
      for generating examples of the underlying task.
    context: A context to be applied to each of the examples.
    inference_engine: The inference engine to be used to calculate replies and
      qualifiers.  It should be the InferenceEngine instance generated along
      with the context.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate. If fewer than `num_examples`
      examples could be added,
      `counters.errors.failed_to_generate_example_of_desired_request_type` is
      incremented.
    rng: Random number generator.
    enable_remote_dependencies: Whether to enable dependencies on remote
      services such as the T5X tokenizer. (See explanation in generate_dataset.)
  """
  max_attempts = num_examples * options.max_attempts_per_example
  target_fraction_by_example_type = _get_target_fraction_by_example_type(
      request_type, options, context, rng)
  if request_type == cl.RequestType.NON_RULE:
    non_rule_example_generator_func = functools.partial(
        simple_example_generator.generate_n_non_rule_examples_with_qualifier,
        n=max_attempts,
        unreliable_rules=context.metadata.unreliable_rules,
        inference_engine=inference_engine,
        context=context)
    examples_by_type = {
        cl.ExampleType.NONRULE_UNKNOWN_D:
            non_rule_example_generator_func(
                unknown_reply=True, qualifier=cl.Qualifier.D),
        cl.ExampleType.NONRULE_KNOWN_M:
            non_rule_example_generator_func(
                unknown_reply=False, qualifier=cl.Qualifier.M),
        cl.ExampleType.NONRULE_KNOWN_D:
            non_rule_example_generator_func(
                unknown_reply=False, qualifier=cl.Qualifier.D)
    }

  else:
    rule_example_generator_func = functools.partial(
        simple_example_generator
        .generate_n_rule_examples_with_reply_and_qualifier,
        n=max_attempts,
        rules_to_avoid=context.metadata.explicit_rules,
        unreliable_rules=context.metadata.unreliable_rules,
        inference_engine=inference_engine,
        context=context)
    examples_by_type = {
        cl.ExampleType.RULE_UNKNOWN_D:
            rule_example_generator_func(
                rule_reply=cl.RuleReply.UNKNOWN, qualifier=cl.Qualifier.D),
        cl.ExampleType.RULE_KNOWN_FALSE_D:
            rule_example_generator_func(
                rule_reply=cl.RuleReply.FALSE, qualifier=cl.Qualifier.D),
        cl.ExampleType.RULE_KNOWN_FALSE_M:
            rule_example_generator_func(
                rule_reply=cl.RuleReply.FALSE, qualifier=cl.Qualifier.M),
        cl.ExampleType.RULE_KNOWN_TRUE_D:
            rule_example_generator_func(
                rule_reply=cl.RuleReply.TRUE, qualifier=cl.Qualifier.D),
        cl.ExampleType.RULE_KNOWN_TRUE_M:
            rule_example_generator_func(
                rule_reply=cl.RuleReply.TRUE, qualifier=cl.Qualifier.M),
    }

  candidate_examples = (
      _merge_example_iterators_using_target_fraction_by_example_type(
          request_type, examples_by_type, target_fraction_by_example_type,
          counters))

  candidate_examples = _filter_requests_already_in_context(
      candidate_examples, context, counters)

  # We populate input/output lengths here rather than earlier in sampling.py
  # so that we can be sure to take into account the context length.
  candidate_examples = _populate_input_output_lengths(
      candidate_examples, context, enable_remote_dependencies)

  candidate_examples = _filter_examples_exceeding_max_input_output_lengths(
      candidate_examples, options, counters)

  examples_added = []
  # time.monotonic() (unlike time.time()) is unaffected by system clock
  # adjustments, so it is the right clock for measuring elapsed time.
  generation_start_time = time.monotonic()
  for example in itertools.islice(
      _add_unique_examples(dataset, candidate_examples, counters),
      num_examples):
    # We update the counters right after adding each example to the dataset here
    # so the counts can be used to decide which example stream to take from next
    # in _merge_example_iterators_using_target_fraction_by_example_type.
    counters.examples.update_with_example_and_context(example, context)
    examples_added.append(example)
    # NOTE: The time limit is only checked after an example has been added, so
    # time spent on candidates that are all rejected is not bounded by it.
    elapsed_time = time.monotonic() - generation_start_time
    if elapsed_time > options.max_generation_time_per_request_type:
      logging.warning(
          'Reached time out limit for generating examples of request type %s.',
          request_type.name)
      break

  if len(examples_added) < num_examples:
    logging.warning(
        'Unable to generate the requested number of requests of type %s '
        'for the given context: %d < %d', request_type.name,
        len(examples_added), num_examples)
    counters.errors.failed_to_generate_example_of_desired_request_type += 1


def _context_length_check_for_string_format(
    context, options,
    string_format):
  """Checks whether the context fits the input budget for one string format.

  Args:
    context: The context whose tokenized length is measured.
    options: Sampling options; the max input length for `string_format` is
      read from `options.lengths`.
    string_format: Either tokenization.ExampleStringFormat.STANDARD or
      tokenization.ExampleStringFormat.COMPACT.

  Returns:
    True if no max input length is configured for this format (falsy limit),
    or if the tokenized context length is at most 95% of that limit; False
    otherwise.

  Raises:
    ValueError: If `string_format` is not one of the two supported formats.
  """
  if string_format == tokenization.ExampleStringFormat.STANDARD:
    limit = options.lengths.max_input_length_standard
  elif string_format == tokenization.ExampleStringFormat.COMPACT:
    limit = options.lengths.max_input_length_compact
  else:
    raise ValueError(f'Unknown example string format: {string_format}')

  # A falsy limit (e.g. 0) means the length is not constrained.
  if not limit:
    return True

  context_string = tokenization.get_context_string(context, string_format)
  tokenized_length = tokenization.get_tokenized_length(context_string,
                                                       string_format)
  # Keep a 5% margin so that top-level requests still fit next to the context.
  return tokenized_length <= 0.95 * limit


def _context_length_check(context,
                          options,
                          counters):
  """Checks that the context leaves enough room for top-level requests.

  If the context alone used up most of the allowed input length, a large
  number of the top-level examples generated for it would later be discarded.
  To avoid this, a context is accepted only if, in both the STANDARD and the
  COMPACT string formats, its tokenized length is at most 95% of the
  corresponding maximum input length. A format with no limit configured
  always passes.

  Args:
    context: The context to check.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.

  Returns:
    True if the context passes for every string format. Otherwise False, in
    which case `counters.context_attempts.exceeded_max_input_length` is
    incremented once.
  """
  formats_to_check = (tokenization.ExampleStringFormat.STANDARD,
                      tokenization.ExampleStringFormat.COMPACT)
  for string_format in formats_to_check:
    if not _context_length_check_for_string_format(context, options,
                                                   string_format):
      # Formats after the first failing one are not evaluated.
      counters.context_attempts.exceeded_max_input_length += 1
      return False
  return True


def _context_quality_check(context,
                           options,
                           counters):
  """Checks whether the context illustrates its rules well enough.

  Two checks are performed, and both always run (so a single call may bump
  both failure counters):
  - Optional, enabled by `options.require_hidden_rules_to_satisfy_inductive_bias`:
    every rule in `context.metadata.hidden_rules` must have GOOD illustration
    quality under `options.inductive_bias`.
  - Consider each unreliable rule that appears in some context example's
    `metadata.distractor_rules_by_unreliable_rule`. Across all context
    examples, such a rule must be associated with more than one distinct
    distractor rule. Unreliable rules that appear in no example are not
    checked here.

  The following requirement is not explicitly checked here, but is expected to
  be enforced in _filter_examples_illustrating_distractor_rules:
  - Every distractor rule should be used in no more than two different context
    examples.

  Args:
    context: The context to check.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.

  Returns:
    True if every enabled check passes, False otherwise.
  """
  passed = True

  if options.require_hidden_rules_to_satisfy_inductive_bias:
    for hidden_rule in context.metadata.hidden_rules:
      quality = induction.get_rule_illustration_quality(
          hidden_rule, context, options.inductive_bias)
      if quality != cl.IllustrationQuality.GOOD:
        counters.context_attempts.poor_illustration_quality += 1
        passed = False
        break

  # Collect, per unreliable rule, every distinct distractor rule used for it
  # anywhere in the context.
  distractors_seen = collections.defaultdict(set)
  for context_example in context:
    mapping = context_example.metadata.distractor_rules_by_unreliable_rule
    for unreliable_rule, distractor_rules in mapping.items():
      distractors_seen[unreliable_rule].update(distractor_rules)

  if any(len(rules) <= 1 for rules in distractors_seen.values()):
    counters.context_attempts.unreliable_rule_illustrated_only_one_way += 1
    passed = False

  return passed


def _generate_infinite_contexts_and_inference_engines_and_example_generators(
    grammar_generator,
    options,
    rng,
    counters,
    template = None,
):
  """Yields (context, inference_engine, example_generator) triples forever.

  Each iteration does the following:
  1. Generates a fresh grammar. If `options.num_rules >= 0`, it is generated
     from a deep copy of `template` on which rules were sampled.
  2. Wraps the grammar in a sampling.ExampleGenerator.
  3. Builds a context and inference engine from that generator via
     _generate_context_and_inference_engine.

  If context generation raises MaxAttemptsReachedError, the iteration is
  retried with a new grammar. Too many consecutive failures abort the stream.

  Because this is a generator function, the argument validation below only
  happens when the first item is requested.

  Args:
    grammar_generator: Grammar generator used to generate the grammar used by
      the example generator.
    options: Bundle of options controlling the generation algorithm.
    rng: Random number generator.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.
    template: A GrammarSchema to be used as the template for generating the
      grammars that underly all contexts of the generated dataset.

  Yields:
    Tuples (context, inference_engine, simple_example_generator), where the
    context and inference engine were produced from the yielded generator.

  Raises:
    ValueError: If `options.num_rules >= 0` but no `template` is given.
    MaxAttemptsReachedError: If context generation fails
      `options.max_attempts_per_grammar` times in a row.
  """
  if template is None and options.num_rules >= 0:
    raise ValueError(
        'Cannot limit number of rules without specifying a grammar template.')

  failures_in_a_row = 0
  while True:
    if template is not None and options.num_rules >= 0:
      # Rules are sampled on a private deep copy, leaving `template` untouched.
      grammar_template = copy.deepcopy(template)
      grammar_template.sample_rules(options=options, rng=rng)
    else:
      grammar_template = template
    grammar = grammar_generator.generate_grammar(grammar_template)
    logging.info('Generated a random grammar: %s', grammar)

    # A single ProductionProvenanceDict is shared by everything generated for
    # this grammar's context.
    provenance_dict = production_composition.ProductionProvenanceDict()
    example_generator = sampling.ExampleGenerator(
        grammar=grammar,
        rng=rng,
        grammar_generator=grammar_generator,
        options=options,
        counters=counters,
        provenance_by_production=provenance_dict)

    try:
      context, inference_engine = _generate_context_and_inference_engine(
          example_generator, options, counters, rng)
    except sampling.MaxAttemptsReachedError as error:
      counters.errors.failed_to_generate_context += 1
      failures_in_a_row += 1
      logging.info('Failed to generate grammar. Consecutive_failures = %d',
                   failures_in_a_row)
      if failures_in_a_row < options.max_attempts_per_grammar:
        continue
      raise sampling.MaxAttemptsReachedError(
          f'Failed to generate a grammar due to reaching maximum number of '
          f'attempts: counters={counters}') from error
    else:
      failures_in_a_row = 0
      yield context, inference_engine, example_generator


def populate_example_group_with_top_level_examples(
    example_group,
    num_examples_by_type,
    simple_example_generator,
    options,
    counters,
    rng,
    inference_engine,
    enable_remote_dependencies = True
    ):
  """Adds top-level examples of each requested type to `example_group`.

  Request types are processed in the iteration order of
  `num_examples_by_type`. Each one is delegated to
  _generate_and_add_top_level_examples_of_request_type, which uses
  `example_group.context` as the context and `example_group` as the dataset
  to add to.

  Args:
    example_group: Example group (with a `context`) to be populated in place.
    num_examples_by_type: Mapping from request type to the number of examples
      of that type to generate.
    simple_example_generator: Example generator for the group's grammar.
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process.
    rng: Random number generator.
    inference_engine: Inference engine generated along with the context.
    enable_remote_dependencies: Whether remote services such as the T5X
      tokenizer may be used.
  """
  for request_type, requested_count in num_examples_by_type.items():
    _generate_and_add_top_level_examples_of_request_type(
        dataset=example_group,
        request_type=request_type,
        num_examples=requested_count,
        simple_example_generator=simple_example_generator,
        context=example_group.context,
        inference_engine=inference_engine,
        options=options,
        counters=counters,
        rng=rng,
        enable_remote_dependencies=enable_remote_dependencies)


def _generate_dataset_content(
    options,
    counters,
    rng,
    template = None,
    enable_remote_dependencies = False):
  """Returns example groups from which a GroupedExampleSet can be built.

  This does the bulk of the work of `generate_dataset`, except for the final
  layer of exception handling. See `generate_dataset` for details.

  Args:
    options: See `generate_dataset`.
    counters:  See `generate_dataset`.
    rng:  See `generate_dataset`.
    template:  See `generate_dataset`.
    enable_remote_dependencies:  See `generate_dataset`.

  Returns:
    A list of cl.ExampleGroup, one per generated context, each holding that
    context plus its top-level examples (shuffled with `rng`). Generation stops
    after `options.sampling.num_contexts` groups, or earlier if the schedule
    from `options.sampling.calculate_schedule_of_examples_by_type()` runs out.

  Raises:
    MaxAttemptsReachedError: If a dataset generation was aborted due to reaching
      the maximum number of attempts for generating a grammar.
  """
  grammar_generator = grammar_generation.GrammarGenerator(
      options=options.grammar, rng=rng)
  example_groups = []

  lengths = options.sampling.lengths
  # Computing either COMPACT length (input or output) needs the remotely
  # loaded T5X tokenizer. So if either COMPACT limit is set while remote
  # dependencies are disabled, both limits are switched off.
  if (not enable_remote_dependencies and
      (lengths.max_input_length_compact or lengths.max_output_length_compact)):
    logging.info(
        'Disabling max_[input|output]_length_compact as remote dependencies '
        'are disabled, and calculating lengths for the COMPACT string '
        'representation would require a remote dependency for loading the T5X '
        'SentencePiece tokenizer.')
    options = dataclasses.replace(
        options,
        sampling=dataclasses.replace(
            options.sampling,
            lengths=dataclasses.replace(
                lengths,
                max_input_length_compact=0,
                max_output_length_compact=0)))
    logging.info('Adjusted generation options: %s', options)

  num_examples_by_type_schedule = (
      options.sampling.calculate_schedule_of_examples_by_type())
  contexts_and_inference_engine_and_example_generators = (
      _generate_infinite_contexts_and_inference_engines_and_example_generators(
          grammar_generator, options.sampling, rng, counters, template))
  # zip() pulls from its arguments left to right. Listing the schedule first
  # means that if it runs out, zip() stops before the infinite context
  # generator produces (and updates counters for) a context that would only be
  # thrown away.
  for num_examples_by_type, context_bundle in zip(
      num_examples_by_type_schedule,
      contexts_and_inference_engine_and_example_generators):
    context, inference_engine, simple_example_generator = context_bundle
    example_group = cl.ExampleGroup(context=context)
    populate_example_group_with_top_level_examples(
        num_examples_by_type=num_examples_by_type,
        example_group=example_group,
        simple_example_generator=simple_example_generator,
        inference_engine=inference_engine,
        options=options.sampling,
        counters=counters,
        rng=rng,
        enable_remote_dependencies=enable_remote_dependencies)

    counters.rules.update_with_context(context)
    counters.contexts.update_with_context(context)
    counters.contexts.update_with_example_group(example_group)
    for context_example in context:
      counters.context_examples.update_with_example_and_context(context_example)

    example_group.shuffle(rng)
    example_groups.append(example_group)
    if len(example_groups) == options.sampling.num_contexts:
      break

  return example_groups


def generate_dataset(
    options,
    counters,
    rng,
    template = None,
    enable_remote_dependencies = False):
  """Generates a set of examples for Conceptual SCAN.

  Generation is based on random sampling and consists of three main steps:
  1. Randomly generating a SCAN-like grammar to form the basis of the
     "context".
  2. Randomly generating the "context", which is itself a set of examples that
     either directly or indirectly illustrate the rules of the given grammar.
  3. Randomly generating examples using that context. These either illustrate
     the behavior of the end-to-end task (i.e., "non-rule requests") or test
     the understanding of which rules hold in that context (i.e., rule
     requests).

  If generation is aborted with MaxAttemptsReachedError,
  `counters.errors.failed_to_generate_grammar` is incremented. In that case an
  empty GroupedExampleSet is returned (any groups generated so far are
  discarded).

  Args:
    options: Bundle of options controlling the generation algorithm.
    counters: Counters tracking statistics about the generation process, which
      will be updated here where appropriate.
    rng: Random number generator.
    template: A GrammarSchema to be used as the template for generating the
      grammars that underly all contexts of the generated dataset.
    enable_remote_dependencies: Whether to enable dependencies on remote
      services such as the T5X tokenizer. If False, features that depend on
      such services (such as the tracking of example input and output lengths
      in COMPACT format using the T5X tokenizer) are disabled, but dataset
      generation otherwise proceeds normally. Should generally be set to False
      in unit tests, where external network dependencies are normally not
      allowed.

  Returns:
    A newly created GroupedExampleSet.
  """
  try:
    groups = _generate_dataset_content(
        options,
        counters,
        rng,
        template,
        enable_remote_dependencies)
  except sampling.MaxAttemptsReachedError:
    counters.errors.failed_to_generate_grammar += 1
    groups = []

  return cl.GroupedExampleSet(example_groups=groups)
