# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for sampling conceptual learning examples from a FeatureGrammar.

Example usage 1 (10 random examples):
  counters = outputs.GenerationCounters()
  example_generator = ExampleGenerator(
      grammar=nltk.grammar.FeatureGrammar.fromstring(...), rng=...,
      counters=counters)
  examples = list(example_generator.generate_n_examples(10))

Example usage 2 (10 examples illustrating a given grammar rule):
  counters = outputs.GenerationCounters()
  example_generator = ExampleGenerator(
      nltk.grammar.FeatureGrammar.fromstring(...), rng=..., counters=counters)
  examples = list(example_generator.generate_n_examples(10,
      target_rule="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"))

This library takes inspiration from the CFG-based generation library in
nltk/parse/generate.py. Compared to that existing library, this library has the
following differences:
- Generates full conceptual learning examples, which include both input and
  output, rather than just input sentences. Assumes the use of a special feature
  tag called 'sem' (as in 'semantics') for representing the output.
- Generates examples via random sampling rather than exhaustive generation.
- Assumes use of a FeatureGrammar rather than an arbitrary context free grammar
  (CFG). Note, however, that the main body of the generation algorithm still
  works off of the CFG core of the FeatureGrammar, rather than performing
  feature unification directly. The feature structures in the FeatureGrammar are
  used instead for the following two purposes:
  1. After generating an input sentence using the CFG core of the grammar,
    ExampleGenerator then verifies that the input sentence can actually be
    parsed using the full FeatureGrammar (using proper feature unification).
    If the input sentence turns out to not be parseable, then ExampleGenerator
    will discard that sentence and generate a new one in its place.
  2. If the input sentence is successfully parsed, then ExampleGenerator uses
    that input sentence as the example's "request" and then extracts the value
    of the 'sem' feature to use as the example's "reply".
- Supports generating examples that use a specified target rule.

Note on randomization and bias:

In order to minimize bias in the sampling process, the library directly
implements a random walk through the grammar, leading to a fresh implementation
independent of nltk/parse/generate.py. A more naive approach could be imagined
that attempts to reuse nltk/parse/generate.py, for example by shuffling the
grammar before generating each example as a way of introducing randomization and
then handing the shuffled grammar off to nltk/parse/generate.py to generate the
actual example. This naive approach would have the problem, however, of being
inordinately biased toward generating examples like 'jump around right thrice
and jump around right thrice' that reuse the same grammar rules repeatedly in
multiple places in the parse tree, since nltk/parse/generate.py's exhaustive
generation algorithm starts from the top of the grammar when searching for rules
to expand any given nonterminal and would thus tend to keep exploring those same
productions as much as possible. The implementation here largely avoids that
sort of bias. A small degree of bias is still present, however, in the case
where a target rule is specified, as ExampleGenerator will then greedily choose
the given target rule over other rules that output the same nonterminal in the
early stages of generation and will switch to a fully random walk only after the
target rule has been used at least once.
"""

import copy
import dataclasses
import itertools
import pprint
from typing import (Callable, Collection, Dict, Iterator, List, Optional,
                    Sequence, Tuple)

from absl import logging
import attr
import nltk
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import distractor_generation
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import parsing
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import rule_conversion


class MaxLevelError(Exception):
  """Raised when a random walk through the grammar hits the derivation cap.

  The cap is the sampling option max_derivation_level.
  """


class MaxAttemptsReachedError(Exception):
  """Raised when generation gives up after exhausting its attempt budget."""


@dataclasses.dataclass
class _ProductionSamplingParameters:
  """Bundle of knobs controlling how one production is sampled.

  Attributes:
    grammar: Grammar whose productions are sampled; usually a sub-grammar of
      the generator's own grammar.
    unreliable_productions: Productions to swap out for distractor productions
      before sampling. Defaults to an empty tuple.
    yield_probability: When set, takes precedence over the sampling options'
      derived_production_yield_probability.
    start: When set, the initial production is drawn only from productions
      whose LHS is this FeatStructNonterminal.
    require_derived: When True, the sampler keeps composing until at least one
      composition has happened, provided a composition is possible.
  """
  grammar: nltk.grammar.FeatureGrammar
  # A tuple is immutable, so it is safe to share as a plain default value.
  unreliable_productions: Sequence[nltk.grammar.Production] = ()
  yield_probability: Optional[float] = None
  start: Optional[nltk.grammar.FeatStructNonterminal] = None
  require_derived: bool = False


@dataclasses.dataclass
class _ProductionSamplingResults:
  """Results of sampling a single production from the grammar.

  Attributes:
    production: The sampled production.
    grammar: The grammar from which the production is sampled.
    dependency_productions: The productions used to build the sampled production
      via production composition, starting with the initial production and in
      the order they were used. A production used more than once appears more
      than once.
    distractor_productions_by_unreliable_production: The mapping from original
      unreliable productions to the distractor productions that are used to
      build the sampled productions.
  """
  production: nltk.grammar.Production
  grammar: nltk.grammar.FeatureGrammar
  # Variable-length tuple: one entry per production used in the derivation.
  dependency_productions: Tuple[nltk.grammar.Production, ...]
  distractor_productions_by_unreliable_production: Dict[
      nltk.grammar.Production, List[nltk.grammar.Production]] = (
          dataclasses.field(default_factory=dict))


# Signature of the filters applied to sampled productions: each one receives the
# sampling results and returns whether to keep them.
_PredicateOnProductionSamplingResults = Callable[[_ProductionSamplingResults],
                                                 bool]


def _replace_production_in_grammar(grammar,
                                   original_production,
                                   new_production):
  """Returns a new grammar with a production replaced.

  Every production equal (==) to original_production is replaced by
  new_production; all other productions keep their original order. The start
  symbol and the productions of the returned grammar are deep copies.

  Args:
    grammar: The grammar from which a production will be replaced.
    original_production: The original production to be replaced.
    new_production: The new production to replace the original production.

  Returns:
    A new nltk.grammar.FeatureGrammar with the same start symbol as grammar.

  Raises:
    ValueError: If the specified original_production is not in grammar.
  """
  # Fetch the production list once; it serves both the membership check and
  # the construction of the replacement list.
  productions = grammar.productions()
  if original_production not in productions:
    raise ValueError(
        f'Production {original_production} not in grammar: {grammar}.')

  new_productions = [
      new_production if production == original_production else production
      for production in productions
  ]

  return nltk.grammar.FeatureGrammar(
      start=copy.deepcopy(grammar.start()),
      productions=copy.deepcopy(new_productions))


class ExampleGenerator:
  """Generator of conceptual learning examples based on a given FeatureGrammar.

  A separate generator should be constructed for each grammar. However, the
  same GenerationCounters object can be shared across generators in order to
  track aggregate stats for a full dataset generation run, which may involve
  many different randomly-generated grammars.
  """

  def __init__(
      self,
      grammar,
      rng,
      options = None,
      counters = None,
      grammar_generator = None,
      provenance_by_production = None):
    """Initializes the generator instance.

    Args:
      grammar: The grammar used to generate the examples.
      rng: Random number generator.
      options: Bundle of options controlling the generation algorithm.
      counters: Counters tracking various statistics about the generation
        process, such as the number of valid and invalid generation attempts. If
        a counters object is provided, ExampleGenerator will increment the
        counters in this object directly, as opposed to making a copy of it.
      grammar_generator: A grammar generator used to generate alternative
        grammars to create distractor rules.
      provenance_by_production: A ProductionProvenanceMapping to record
        production provenances of all compositions taking place during example
        generation. If a mapping is provided (even an empty one), it is updated
        in place rather than replaced by a fresh one.
    """
    self._grammar = grammar
    self._rng = rng
    # Defaults are applied only when an argument is None. Relying on truthiness
    # would silently drop a caller-supplied object that happens to be falsy,
    # most notably an empty provenance mapping the caller wants populated.
    self._options = (
        options if options is not None else inputs.SamplingOptions())
    self._counters = (
        counters if counters is not None else outputs.GenerationCounters())
    self._grammar_generator = (
        grammar_generator if grammar_generator is not None else
        grammar_generation.GrammarGenerator())
    self.provenance_by_production = (
        provenance_by_production if provenance_by_production is not None else
        production_composition.ProductionProvenanceDict())
    self._productions_by_rule = (
        rule_conversion.rule_mapping_from_grammar(self._grammar,
                                                  self._options.rule_format))

    # Inverted mapping from production strings to target rules.  This is needed
    # in order to make sure ExampleMetadata.rules contains rules in the
    # specified rule format.  If a production string is listed under several
    # rules, the last one in iteration order wins.
    self._rule_by_production_string = {
        str(production): rule
        for rule, productions in self._productions_by_rule.items()
        for production in productions
    }

  def get_grammar(self):
    """Gives back the FeatureGrammar this generator samples examples from."""
    underlying_grammar = self._grammar
    return underlying_grammar

  def get_rules(self):
    """Lists the grammar's rules in the key order of the rule mapping."""
    return list(self._productions_by_rule)

  def get_productions_without_rules(self):
    """Returns the grammar productions that are not mapped to any rule.

    Productions are returned in grammar order; a production that appears more
    than once in the grammar appears more than once in the result.

    Returns:
      A list of the productions of self._grammar that are not listed under any
      rule in self._productions_by_rule.
    """
    # Collect every production that has a rule once, instead of rescanning all
    # rules' production lists for each grammar production.  Productions are
    # hashable (they are used as dict keys elsewhere in this class).
    productions_with_rules = set(
        itertools.chain.from_iterable(self._productions_by_rule.values()))
    return [
        production for production in self._grammar.productions()
        if production not in productions_with_rules
    ]

  def _rules_from_productions(
      self, productions):
    """Converts productions into rules of the configured rule format.

    Productions for which rule_conversion.rule_from_production yields None are
    dropped; the remaining rules keep the input order, duplicates included.
    """
    candidate_rules = (
        rule_conversion.rule_from_production(production,
                                             self._options.rule_format)
        for production in productions)
    return [rule for rule in candidate_rules if rule is not None]

  def _required_rule_from_production(self, production):
    """Returns the rule for `production`, raising if the format has none."""
    rule = rule_conversion.rule_from_production(production,
                                                self._options.rule_format)
    if rule is None:
      raise ValueError('Production does not have a rule in the given format: '
                       f'production={production}, '
                       f'rule_format={self._options.rule_format}')
    return rule

  def _rule_and_dependencies_from_production_sampling_results(
      self, results
  ):
    """Converts production and dependency to rule and dependency.

    Args:
      results: The results from production sampling.

    Returns:
      A tuple of (rule, dependency rules, distractors by unreliable rule
        mapping). The dependency rules are in the same order as the original
        dependency productions and are not deduplicated; dependency productions
        without a rule in the desired format are skipped.

    Raises:
      ValueError: If results.production, or any unreliable production in
      results.distractor_productions_by_unreliable_production, does not have a
      corresponding rule in the desired rule format, for example
      PassThroughRules in the interpretation rule format.
    """
    rule = self._required_rule_from_production(results.production)

    dependency_rules = self._rules_from_productions(
        results.dependency_productions)

    distractor_rules_by_unreliable_rule = {}
    for unreliable_production, distractor_productions in (
        results.distractor_productions_by_unreliable_production.items()):
      unreliable_rule = self._required_rule_from_production(
          unreliable_production)
      distractor_rules_by_unreliable_rule[unreliable_rule] = (
          self._rules_from_productions(distractor_productions))

    return rule, dependency_rules, distractor_rules_by_unreliable_rule

  def get_productions_from_rule(self,
                                rule):
    """Looks up every production mapped to `rule`.

    A rule unknown to the generator raises KeyError.
    """
    matching_productions = self._productions_by_rule[rule]
    return matching_productions

  def _get_production_from_rule(self, rule):
    """Picks one of the productions mapped to `rule` at random.

    Args:
      rule: The rule whose productions are sampled from.

    Returns:
      A production chosen with self._rng.choice.

    Raises:
      ValueError: If the rule is not part of the grammar.
    """
    if rule in self._productions_by_rule:
      return self._rng.choice(self._productions_by_rule[rule])
    raise ValueError(f'Rule not found in grammar: rule={rule}, '
                     f'grammar={self._grammar}.')

  def get_example_for_explicit_rule(self, rule):
    """Builds an Example whose request is `rule` itself, answered TRUE.

    Used when adding examples to the context during dataset generation. The
    metadata records one randomly chosen production for the rule, its number
    of variables, and a provenance naming that production as its own source.

    Args:
      rule: The rule to be asserted.
    """
    production = self._get_production_from_rule(rule)
    num_variables = production_composition.num_variables_from_production(
        production)
    provenance = production_composition.ProductionProvenance(source=production)
    metadata = cl.ExampleMetadata(
        rules={rule},
        target_rule=rule,
        num_variables=num_variables,
        production=production,
        production_provenance=provenance)
    return cl.Example(request=rule, reply=cl.RuleReply.TRUE, metadata=metadata)

  def _sample_production_and_dependencies(
      self,
      parameters):
    """Returns a random production together with its dependency productions.

    The walk starts from a random production (restricted to productions with
    parameters.start on the LHS when it is given) and repeatedly composes the
    current production with a random production whose LHS is one of its RHS
    nonterminals that the grammar can expand. Before each composition a random
    draw is made: with probability yield_probability the walk stops and the
    current production is returned. yield_probability defaults to
    derived_production_yield_probability in the sampling options. When
    parameters.require_derived is set, the walk does not stop before at least
    one composition has happened, as long as a composition is possible.

    The sequence of dependency productions is in order of being used in
    production compositions, starting with the initial production. If a
    production is used multiple times, then it will appear multiple times in
    the sequence.

    The initial production is registered in self.provenance_by_production as
    its own source, and that mapping is handed to production_composition.compose
    for every composition.

    Args:
      parameters: The parameters used for the production sampling process.

    Returns:
      A _ProductionSamplingResults holding the sampled production, the grammar
      it was sampled from, and the dependency productions.

    Raises:
      MaxLevelError: If a derived production failed to be sampled due to
        reaching the maximum level.
    """
    yield_probability = (
        parameters.yield_probability if parameters.yield_probability is not None
        else self._options.derived_production_yield_probability)
    grammar = parameters.grammar

    def should_compose(dependency_productions):
      # The random draw happens on every call, even when require_derived forces
      # a composition, so the random stream is consumed the same way.
      return (self._rng.random() > yield_probability or
              (parameters.require_derived and len(dependency_productions) < 2))

    def composable_indices_from_production(production):
      # Indices of RHS nonterminals that some production in `grammar` expands.
      return [
          i for i, item in enumerate(production.rhs())
          if isinstance(item, nltk.grammar.Nonterminal) and grammar.productions(
              lhs=item)
      ]

    # Compare with None: a FeatStructNonterminal is mapping-like (note the
    # lhs()['sem'] lookups in this class), so its truthiness reflects its
    # contents rather than whether a start symbol was given.
    if parameters.start is not None:
      production = self._rng.choice(grammar.productions(lhs=parameters.start))
    else:
      production = self._rng.choice(grammar.productions())
    self.provenance_by_production.setdefault(
        production,
        production_composition.ProductionProvenance(source=production))
    dependency_productions = (production,)
    level = 0
    # Computed once per production and reused for both the loop test and the
    # choice of the expansion site.
    composable_indices = composable_indices_from_production(production)
    while composable_indices and should_compose(dependency_productions):
      parent = production
      i = self._rng.choice(composable_indices)
      item = parent.rhs()[i]
      other_parent = self._rng.choice(grammar.productions(lhs=item))
      dependency_productions += (other_parent,)
      production = production_composition.compose(parent, other_parent, i,
                                                  self.provenance_by_production)
      level += 1
      if level >= self._options.max_derivation_level:
        raise MaxLevelError(
            'Failed to sample derived production due to reaching maximum level '
            'in random walk through grammar.')
      composable_indices = composable_indices_from_production(production)

    return _ProductionSamplingResults(production, grammar,
                                      dependency_productions)

  def _grammar_with_distractor_replacements(self, parameters):
    """Returns (grammar, unreliable_production_by_distractor_production).

    Each of parameters.unreliable_productions is replaced, in a deep copy of
    parameters.grammar, by a freshly created distractor production. Unreliable
    productions for which no distractor could be created are kept as is.
    """
    grammar = copy.deepcopy(parameters.grammar)
    unreliable_production_by_distractor_production = {}
    for unreliable_production in parameters.unreliable_productions:
      try:
        distractor_production, grammar = (
            self._create_distractor_production_and_grammar(
                unreliable_production, grammar))
      except distractor_generation.FailedToCreateDistractorError:
        # Best effort: if distractor_generation fails to create a distractor
        # after a number of attempts, the unreliable production stays as is.
        continue
      # Keep track of the original unreliable productions so we can populate
      # ExampleMetadata.unreliable_rules and allow predicates to check against
      # unreliable dependencies in the case when the target rule itself is
      # unreliable.
      unreliable_production_by_distractor_production[
          distractor_production] = unreliable_production
    return grammar, unreliable_production_by_distractor_production

  def _generate_one_production_and_dependencies(
      self, predicates,
      parameters):
    """Returns sampling results whose production satisfies all predicates.

    For every call, the unreliable productions in parameters are first replaced
    with distractor productions in a copy of parameters.grammar. Then up to
    max_attempts_per_example productions are sampled from that grammar; the
    first one that has a corresponding rule in the example generator's rule
    format and satisfies every predicate is returned. For example,
    PassThroughRules do not have a rule in the interpretation rule format.

    Args:
      predicates: A sequence of functions that take a _ProductionSamplingResults
        and return a bool value.  Only results that evaluate to True for all
        predicates will be returned.  An empty or None sequence accepts all.
      parameters: The parameters used for the production sampling process.

    Returns:
      A _ProductionSamplingResults whose
      distractor_productions_by_unreliable_production maps each unreliable
      production to the distractor productions among its dependencies.

    Raises:
      MaxAttemptsReachedError: If a production could not be generated due to
        reaching the maximum number of attempts.
    """
    grammar, unreliable_production_by_distractor_production = (
        self._grammar_with_distractor_replacements(parameters))
    parameters = dataclasses.replace(parameters, grammar=grammar)

    for _ in range(self._options.max_attempts_per_example):
      try:
        results = self._sample_production_and_dependencies(parameters)
      except MaxLevelError:
        self._counters.example_attempts.max_derivation_level_reached += 1
        continue

      rule = rule_conversion.rule_from_production(results.production,
                                                  self._options.rule_format)
      if not rule:
        continue

      distractor_productions_by_unreliable_production = {}
      for production in results.dependency_productions:
        if production in unreliable_production_by_distractor_production:
          unreliable_production = (
              unreliable_production_by_distractor_production[production])
          distractor_productions_by_unreliable_production.setdefault(
              unreliable_production, []).append(production)

      results = dataclasses.replace(
          results,
          distractor_productions_by_unreliable_production=(
              distractor_productions_by_unreliable_production))

      if (not predicates or
          all(predicate(results) for predicate in predicates)):
        return results

    self._counters.errors.failed_to_generate_derived_production += 1
    # Report the grammar actually sampled from (pruned and/or containing
    # distractors), which is what matters when debugging the failure.
    raise MaxAttemptsReachedError(
        f'Failed to generate derived production due to reaching maximum '
        f'number of attempts: predicates={predicates}, '
        f'grammar={grammar}, counters={self._counters}')

  def _generate_infinite_productions_and_dependencies(
      self, predicates,
      parameters
  ):
    """Yields an infinite stream of production sampling results.

    Sampling is performed with replacement.  The stream ends (after logging the
    exception) as soon as generating any single result reaches
    max_attempts_per_example attempts.

    Args:
      predicates: A sequence of functions that take a _ProductionSamplingResults
        and return a bool value.  Only results that evaluate to True for all
        predicates will be yielded.
      parameters: The parameters used for the production sampling process.

    Yields:
      _ProductionSamplingResults, one per successfully generated production.
    """
    while True:
      try:
        results = self._generate_one_production_and_dependencies(
            predicates, parameters)
      except MaxAttemptsReachedError:
        logging.exception('Failed to generate derived productions.')
        return
      yield results

  def _generate_n_productions_and_dependencies(
      self,
      n,
      predicates,
      parameters = None
  ):
    """Yields up to n randomly-sampled production sampling results.

    Sampling is performed with replacement, so the productions are not
    guaranteed to be unique.  Fewer than n results are yielded if the maximum
    number of attempts is reached during the generation of one production.

    Args:
      n: Number of productions to return.
      predicates: A sequence of functions that take a _ProductionSamplingResults
        and return a bool value.  Only results that evaluate to True for all
        predicates will be yielded.
      parameters: The parameters used for the production sampling process.
        Defaults to sampling from the generator's full grammar.

    Yields:
      _ProductionSamplingResults, at most n of them.
    """
    if parameters is None:
      parameters = _ProductionSamplingParameters(grammar=self._grammar)
    yield from itertools.islice(
        self._generate_infinite_productions_and_dependencies(
            predicates, parameters), n)

  def _avoids_rules(
      self,
      rules_to_avoid = None
  ):
    """Builds a predicate rejecting results whose production maps to a listed rule.

    Args:
      rules_to_avoid: Rules that the sampled production must not correspond to
        in the configured rule format. None or empty rejects nothing.
    """
    forbidden_rules = frozenset(rules_to_avoid or ())

    def avoids_rules(results):
      produced_rule = rule_conversion.rule_from_production(
          results.production, self._options.rule_format)
      return produced_rule not in forbidden_rules

    return avoids_rules

  def _uses_target_production(
      self,
      target_production = None
  ):
    """Returns a predicate that checks if target_production is depended upon.

    Args:
      target_production: The production that should appear in
        dependency_productions. If None, the predicate accepts every result.
    """

    def uses_target_production(results):
      # Compare with None explicitly: nltk Productions define __len__ as the
      # RHS length, so a production with an empty RHS is falsy and would
      # otherwise be mistaken for "no target".
      return (target_production is None or
              target_production in results.dependency_productions)

    return uses_target_production

  def _uses_target_unreliable_production(
      self, target_unreliable_production
  ):
    """Builds a predicate testing whether a given unreliable production was used.

    Args:
      target_unreliable_production: The production expected among the keys of
        results.distractor_productions_by_unreliable_production.
    """

    def uses_target_unreliable_production(results):
      distractors_by_unreliable = (
          results.distractor_productions_by_unreliable_production)
      return target_unreliable_production in distractors_by_unreliable

    return uses_target_unreliable_production

  def _validate_tokenized_sentence(
      self,
      parser,
      tokenized_sentence,
      target_production = None):
    """Returns True if the tokenized sentence is valid.

    A sentence is valid when it has exactly one semantic parse and, if
    target_production is given, the production's string form is among the
    parse's source rule strings. Each rejection increments the matching
    example_attempts counter and logs a warning.

    Args:
      parser: Parser exposing semantic_parse(tokens), which returns
        (semantics, source_rule_strings) pairs.
      tokenized_sentence: The tokens to parse.
      target_production: If not None, a production that the parse must use.
    """
    semantics_with_source_rule_strings = tuple(
        parser.semantic_parse(tokenized_sentence))
    if not semantics_with_source_rule_strings:
      self._counters.example_attempts.unparseable += 1
      logging.warning('Skipping unparseable sentence: %s',
                      pprint.pformat(tokenized_sentence))
      return False

    if len(semantics_with_source_rule_strings) > 1:
      self._counters.example_attempts.ambiguous += 1
      logging.warning('Skipping ambiguous sentence: %s',
                      pprint.pformat(tokenized_sentence))
      return False
    _, source_rule_strings = semantics_with_source_rule_strings[0]

    # Compare with None explicitly: nltk Productions define __len__ as the RHS
    # length, so an empty-RHS target would otherwise skip this check.
    if (target_production is not None and
        str(target_production) not in source_rule_strings):
      logging.warning(
          'Skipping sentence lacking target rule (%s): %s'
          '\n(source_rule_strings: %s)', target_production,
          pprint.pformat(tokenized_sentence), source_rule_strings)
      self._counters.example_attempts.missing_target_rule += 1
      return False

    return True

  def _valid_tokenized_sentence(
      self,
      target_production = None
  ):
    """Builds a predicate accepting results whose production RHS parses cleanly.

    The predicate parses results.production.rhs() with a
    parsing.RuleTrackingSemanticParser built from results.grammar and leaves
    the verdict to self._validate_tokenized_sentence.

    Args:
      target_production: Production whose string form should be among the
        parse's source rule strings.
    """

    def valid_tokenized_sentence(results):
      sentence_parser = parsing.RuleTrackingSemanticParser(results.grammar)
      tokens = results.production.rhs()
      return self._validate_tokenized_sentence(sentence_parser, tokens,
                                               target_production)

    return valid_tokenized_sentence

  def _not_source_rule(self):
    """Builds a predicate that rejects results that are just a source rule.

    Some source rule productions have no rule in a given rule format (for
    example PassThroughRules in the interpretation rule format), so instead of
    counting compositions this checks that more than one dependency production
    (duplicates counted) has a rule in the desired format.
    """

    def not_source_rule(results):
      known_production_strings = self._rule_by_production_string
      productions_with_rules = sum(
          1 for dependency in results.dependency_productions
          if str(dependency) in known_production_strings)
      return productions_with_rules > 1

    return not_source_rule

  def _no_nonterminal(self):
    """Builds a predicate rejecting productions with RHS nonterminals."""

    def no_nonterminal(results):
      return all(not isinstance(symbol, nltk.grammar.Nonterminal)
                 for symbol in results.production.rhs())

    return no_nonterminal

  def _grammar_from_rules_to_avoid_as_dependency(
      self, rules_to_avoid_as_dependency
  ):
    """Returns a grammar whose productions are allowed to be used as dependency.

    The main use of this method is to create a "pruned" grammar when we know
    that some productions are not allowed during sampling (e.g. omitted rules),
    in which case it is more efficient to limit the space of productions instead
    of filtering out derived productions after sampling is done based on the
    dependencies.

    Args:
      rules_to_avoid_as_dependency: The rules whose source rule productions are
        disallowed as dependency when sampling for derived productions.

    Returns:
      self._grammar itself if there is nothing to avoid and therefore all
      productions in the grammar are allowed; otherwise a new FeatureGrammar
      with the same (deep-copied) start symbol and deep copies of the allowed
      productions, in their original order.
    """
    if not rules_to_avoid_as_dependency:
      return self._grammar

    productions = self._grammar.productions()
    disallowed_dependencies = [
        production for production in productions
        if str(production) in self._rule_by_production_string and
        self._rule_by_production_string[str(production)]
        in rules_to_avoid_as_dependency
    ]

    # Filter by == against the disallowed list, so that any production equal
    # to a disallowed one is dropped as well.
    allowed_dependencies = [
        production for production in productions
        if production not in disallowed_dependencies
    ]

    return nltk.grammar.FeatureGrammar(
        start=copy.deepcopy(self._grammar.start()),
        productions=copy.deepcopy(allowed_dependencies))

  def _create_distractor_production(
      self, production
  ):
    """Creates a distractor for `production`; returns the creation result.

    One of two strategies is picked at random: with probability
    options.alternative_grammar_fraction the distractor is derived using an
    alternative grammar, otherwise by heuristic edits (parameterized by
    options.max_edits). The same fraction is currently used for both top-level
    examples and unreliable rules.

    Args:
      production: The production to create a distractor for.
    """
    use_alternative_grammar = (
        self._rng.random() < self._options.alternative_grammar_fraction)
    if not use_alternative_grammar:
      return (distractor_generation
              .create_distractor_production_with_heuristic_edit(
                  production, self._grammar, self._rng,
                  self._options.max_edits))
    return (distractor_generation
            .create_distractor_production_with_alternative_grammar(
                production, self._grammar, self._rng, self._grammar_generator,
                self.provenance_by_production,
                self._options.max_attempts_per_negative_example))

  def _create_distractor_production_and_grammar(
      self, production,
      grammar
  ):
    """Creates a distractor for `production` and swaps it into `grammar`.

    Args:
      production: The production for which a distractor production will be
        created.
      grammar: A grammar containing the target production.

    Returns:
      A (distractor production, new grammar) tuple, where the new grammar is
      `grammar` with `production` replaced by the distractor.

    Raises:
      ValueError: If `production` is not in `grammar`.
    """
    distractor = self._create_distractor_production(production).distractor
    updated_grammar = _replace_production_in_grammar(grammar, production,
                                                     distractor)
    return distractor, updated_grammar

  def _build_example(self, production,
                     metadata,
                     request_type):
    """Builds the Example of the given request type from production and metadata.

    The caller is responsible for setting the returned example's qualifier.

    For RULE requests the request is the production's rule and the reply is
    TRUE. For any other request type the request is the production's RHS and
    the reply its LHS 'sem' value, both joined with spaces after dropping empty
    items (a string 'sem' value is used as is); the metadata then also gets
    as_rule set. In that case the production needs to have no variables left,
    which this method does not check, so the caller must ensure it.

    Args:
      production: The production of the example.
      metadata: The example metadata of the example.
      request_type: The request type of the example.
    """
    rule = rule_conversion.rule_from_production(production,
                                                self._options.rule_format)

    if request_type == cl.RequestType.RULE:
      return cl.Example(
          request=rule, reply=cl.RuleReply.TRUE, metadata=metadata)

    rule_metadata = attr.evolve(metadata, as_rule=rule)
    request = ' '.join(filter(bool, production.rhs()))
    semantics = production.lhs()['sem']
    if isinstance(semantics, str):
      reply = semantics
    else:
      reply = ' '.join(filter(bool, semantics))
    return cl.Example(request=request, reply=reply, metadata=rule_metadata)

  def _example_from_production_sampling_results(
      self,
      results,
      request_type,
      target_rule = ''):
    """Builds an Example from a _ProductionSamplingResults.

    The metadata collects the dependency rules (as a set), the derivation
    level (one less than the number of dependency productions), the number of
    variables, the distractor rules by unreliable rule, and the recorded
    provenance of the production.

    Args:
      results: The ProductionSamplingResults used to construct the Example.
      request_type: The Example's request type, RULE or NON_RULE.
      target_rule: If specified, used to populate ExampleMetadata.target_rule.
    """
    sampled_production = results.production
    num_dependencies = len(results.dependency_productions)
    _, dependency_rules, distractor_rules_by_unreliable_rule = (
        self._rule_and_dependencies_from_production_sampling_results(results))

    example_metadata = cl.ExampleMetadata(
        rules=set(dependency_rules),
        target_rule=target_rule,
        derivation_level=num_dependencies - 1,
        num_variables=production_composition.num_variables_from_production(
            sampled_production),
        distractor_rules_by_unreliable_rule=(
            distractor_rules_by_unreliable_rule),
        production=sampled_production,
        production_provenance=self.provenance_by_production[
            sampled_production])

    return self._build_example(sampled_production, example_metadata,
                               request_type)

  def _example_from_source_examples(self, production,
                                    source_examples,
                                    request_type):
    """Builds a top-level Example by merging the metadata of source examples.

    Args:
      production: The production the top-level example is built from.
      source_examples: The examples whose productions were composed to obtain
        `production`.
      request_type: The cl.RequestType of the example.

    Returns:
      The Example built by _build_example. Its metadata combines the rules,
      derivation levels and distractor mappings of all source examples.
    """
    merged_rules = set()
    merged_distractors = {}
    level_sum = 0
    for source in source_examples:
      source_metadata = source.metadata
      merged_rules.update(source_metadata.rules)
      level_sum += source_metadata.derivation_level
      cl.extend_mapping_of_lists_unique(
          merged_distractors,
          source_metadata.distractor_rules_by_unreliable_rule)

    # level_sum only counts derivation steps inside each source example;
    # composing the N source productions together adds N - 1 more.
    total_level = level_sum + len(source_examples) - 1

    combined_metadata = cl.ExampleMetadata(
        rules=merged_rules,
        derivation_level=total_level,
        num_variables=production_composition.num_variables_from_production(
            production),
        distractor_rules_by_unreliable_rule=merged_distractors,
        production=production,
        production_provenance=self.provenance_by_production[production])

    return self._build_example(production, combined_metadata, request_type)

  def _generate_n_rule_examples(
      self,
      n,
      rules_to_avoid = None,
      unreliable_rules = None):
    """Yields n rule Examples sampled from the grammar.

    Each yielded example increments the `valid` attempt counter.

    Args:
      n: Number of rules to return.
      rules_to_avoid: Rules that the yielded examples' own rules must not be.
        Dependency rules of a yielded example may still be among them.
      unreliable_rules: Optional rules to replace with distractor rules during
        generation. Fresh distractors are created for every example.
    """
    productions_to_distract = [
        self._get_production_from_rule(rule)
        for rule in (unreliable_rules or [])
    ]
    sampling_parameters = _ProductionSamplingParameters(
        grammar=self._grammar, unreliable_productions=productions_to_distract)
    sampled = self._generate_n_productions_and_dependencies(
        n, [self._avoids_rules(rules_to_avoid)], sampling_parameters)

    for sampling_results in sampled:
      rule_example = self._example_from_production_sampling_results(
          sampling_results, request_type=cl.RequestType.RULE)
      self._counters.example_attempts.valid += 1
      yield rule_example

  def _generate_n_examples_with_qualifier(
      self, n, inference_engine,
      context, qualifier,
      request_type):
    """Yields Examples of the specified request type and qualifier.

    Instead of sampling from the grammar, this draws productions without
    replacement from those the inference engine classifies under `qualifier`.
    Each example is then built from the examples (context examples, or
    examples for hidden rules) that correspond to the production's source
    productions.

    Args:
      n: Maximum number of examples to yield. Fewer are yielded if the pool of
        eligible productions runs out first.
      inference_engine: The InferenceEngine holding the productions derivable
        from the context.
      context: The context whose examples are used as source examples.
      qualifier: cl.Qualifier.M draws from the engine's monotonic productions.
        Any other value draws from its defeasible productions.
      request_type: cl.RequestType.RULE or NON_RULE. For NON_RULE the pool is
        further restricted to productions with 0 variables that are returned
        by get_productions_of_lhs_symbol for the grammar start symbol's TYPE.

    Yields:
      Examples built by _example_from_source_examples. Their qualifier is not
      set here. Each yielded example increments the `valid` attempt counter.

    Raises:
      ValueError: If a non-pass-through source production has no matching
        hidden-rule example or context example.
    """

    # Unbounded generator over the eligible productions; the caller below caps
    # it at n with itertools.islice.
    def infinite_examples_with_qualifier():
      if request_type == cl.RequestType.RULE:
        if qualifier == cl.Qualifier.M:
          productions_with_qualifier = list(
              inference_engine.monotonic_productions)
        else:
          productions_with_qualifier = list(
              inference_engine.defeasible_productions)
      else:
        # For NON_RULE requests we sample only from productions that have no
        # variables.
        if qualifier == cl.Qualifier.M:
          productions_with_qualifier = list(
              inference_engine.monotonic_productions.intersection(
                  inference_engine.get_productions_of_num_variables(
                      0)).intersection(
                          inference_engine.get_productions_of_lhs_symbol(
                              self._grammar.start()[nltk.grammar.TYPE])))
        else:
          # In rare cases the following intersection could be empty, and there
          # would be no defeasible nonrule examples.  This happens, for example,
          # if the concat rule is the only hidden rule, and both
          # D[sem=...] -> U[sem=?x1] 'right', and
          # D[sem=...] -> U[sem=?x1] 'left' appear as context examples'
          # productions.  In this case any production without variable using the
          # hidden rule can be realized as a composition of monotonic
          # productions, and is therefore monotonic.
          productions_with_qualifier = list(
              inference_engine.defeasible_productions.intersection(
                  inference_engine.get_productions_of_num_variables(
                      0)).intersection(
                          inference_engine.get_productions_of_lhs_symbol(
                              self._grammar.start()[nltk.grammar.TYPE])))

      # Hidden rules are not explicitly used for generating the context, but
      # they could still show up as source productions, so we create the
      # corresponding examples that could be used as source examples.
      source_example_by_hidden_rule = {
          hidden_rule: self.get_example_for_explicit_rule(hidden_rule)
          for hidden_rule in context.metadata.hidden_rules
      }

      # Pass-through rules do not give rise to top-level examples.  This set
      # contains not only the pass-through rules in the grammar, but also
      # compositions of pass-through rules.
      productions_to_exclude = set([
          production for production in inference_engine.all_productions
          if nltk_utils.is_pass_through_rule(production)
      ])

      productions_with_qualifier = [
          production for production in productions_with_qualifier
          if production not in productions_to_exclude
      ]

      # Since the productions were originally stored in sets, we sort this list
      # before shuffling to ensure deterministic behavior.
      productions_with_qualifier = sorted(productions_with_qualifier, key=str)

      logging.info('Sampling from %d productions for %s %s examples.',
                   len(productions_with_qualifier), qualifier, request_type)

      # We shuffle and pop from the eligible productions instead of sampling,
      # since sometimes productions_with_qualifier has very few productions in
      # it by reasons similar to how it could be empty as explained in the
      # comments above.
      self._rng.shuffle(productions_with_qualifier)
      while productions_with_qualifier:
        production = productions_with_qualifier.pop()
        source_productions = inference_engine.get_source_productions(production)

        # Collect one source example per non-pass-through source production,
        # in the order returned by get_source_productions.
        source_examples = []
        for source_production in source_productions:
          # We discard pass-through rules among source productions since they do
          # not correspond to context examples.
          if source_production in productions_to_exclude:
            continue

          # For the purpose of building an example from source examples, we only
          # need to find source examples that have the same rule string (and
          # not necessarily the same production, where the ambiguity is caused
          # by pass-through rules).
          source_rule = rule_conversion.rule_from_production(
              source_production, self._options.rule_format)

          # Hidden-rule examples take precedence over context examples.
          if source_rule in source_example_by_hidden_rule:
            source_examples.append(source_example_by_hidden_rule[source_rule])
            continue
          else:
            # Linear scan of the context; the first matching example wins.
            # RULE examples match on their request string, NON_RULE examples
            # on metadata.as_rule.
            for context_example in context:
              context_example_request_type = context_example.get_request_type()

              if ((context_example_request_type == cl.RequestType.RULE and
                   context_example.request == source_rule) or
                  (context_example_request_type == cl.RequestType.NON_RULE and
                   context_example.metadata.as_rule == source_rule)):
                source_examples.append(context_example)
                break
            # for/else: this branch runs only if the loop above finished
            # without hitting `break`, i.e. no context example matched.
            else:
              # If we get to this point, the source rule is not a pass-through
              # rule, but does not have a source example.  So something must be
              # wrong.
              raise ValueError(
                  f'Source production {source_production} not found in context '
                  'or hidden rules.'
                  f'\nrules: {context.metadata.rules}'
                  f'\nexplicit rules: {context.metadata.explicit_rules}'
                  f'\nomitted rules: {context.metadata.omitted_rules}'
                  f'\nunreliable rules: {context.metadata.unreliable_rules}'
                  f'\nhidden rules: {context.metadata.hidden_rules}'
                  f'\nproduction: {production}'
                  f'\nsource productions: {source_productions}'
                  '\ncontext productions: '
                  f'{[e.metadata.production for e in context]}')

        example = self._example_from_source_examples(production,
                                                     source_examples,
                                                     request_type)

        self._counters.example_attempts.valid += 1
        yield example

    # Cap the otherwise pool-bounded stream at n examples.
    yield from itertools.islice(infinite_examples_with_qualifier(), n)

  def _should_sample_from_grammar(
      self, inference_engine):
    # During testing we sometimes use a fake inference engine to speed up
    # tests.
    return not inference_engine.all_productions

  # This is used for top-level examples.
  def generate_n_rule_examples_with_reply_and_qualifier(
      self,
      n,
      rule_reply,
      qualifier,
      inference_engine,
      context,
      rules_to_avoid = None,
      unreliable_rules = None,
  ):
    """Yields up to n rule examples with specified reply and qualifier.

    Candidate rule examples are generated first. They come from the grammar
    for UNKNOWN replies, or when the inference engine has no productions;
    otherwise they come from the engine's productions with the requested
    qualifier. Candidates are then filtered so that their reply and qualifier
    agree with the inference engine. Every rejected candidate increments a
    specific example_attempts counter, decrements `valid`, and is not
    yielded. As a result, fewer than n examples may be produced.

    Args:
      n: Number of rules to return.
      rule_reply: The reply to use for the generated examples. Should be one of
        the constants defined in cl.RuleReply.
      qualifier: Yield only examples of the provided qualifier.  Ignored if
        rule_reply is UNKNOWN (since then the qualifier must be Defeasible).
      inference_engine: The InferenceEngine instance containing information
        about the context.
      context: The context generated along with inference_engine.
      rules_to_avoid: The function will only return rules not in this sequence,
        but dependency rules could still appear in it.
      unreliable_rules: If provided, these rules will be replaced with
        distractor rules while generating examples.  Every generated example
        will use newly created distractor rules.

    Yields:
      cl.Example instances whose reply is `rule_reply`. UNKNOWN examples get
      qualifier cl.Qualifier.D; all others get `qualifier`.
    """
    if (rule_reply == cl.RuleReply.UNKNOWN or
        self._should_sample_from_grammar(inference_engine)):
      rule_examples = self._generate_n_rule_examples(n, rules_to_avoid,
                                                     unreliable_rules)
    else:
      rule_examples = self._generate_n_examples_with_qualifier(
          n, inference_engine, context, qualifier, cl.RequestType.RULE)

    if rule_reply == cl.RuleReply.UNKNOWN:
      for example in rule_examples:
        # A rule the engine already contains is not UNKNOWN.
        if inference_engine.contains_production(example.metadata.production):
          (self._counters.example_attempts.wrong_reply_when_targeting_unknown
          ) += 1
          self._counters.example_attempts.valid -= 1
          continue

        # A rule that would make the context inconsistent is not UNKNOWN
        # either, so drop it. Without this `continue` the example was yielded
        # as UNKNOWN even though it had already been counted as invalid. This
        # matches generate_n_non_rule_examples_with_qualifier.
        if inference_engine.inconsistency_if_production_added(
            example.metadata.production) is not None:
          self._counters.example_attempts.rule_example_inconsistent += 1
          self._counters.example_attempts.valid -= 1
          continue

        example = attr.evolve(
            example, reply=cl.RuleReply.UNKNOWN, qualifier=cl.Qualifier.D)
        yield example

    elif rule_reply == cl.RuleReply.TRUE and qualifier == cl.Qualifier.M:
      for example in rule_examples:
        if not inference_engine.contains_monotonic_production(
            example.metadata.production):
          (self._counters.example_attempts
           .wrong_qualifier_when_targeting_monotonic) += 1
          self._counters.example_attempts.valid -= 1
          continue

        example = attr.evolve(
            example, reply=cl.RuleReply.TRUE, qualifier=qualifier)
        yield example

    elif rule_reply == cl.RuleReply.TRUE and qualifier == cl.Qualifier.D:
      for example in rule_examples:
        if not inference_engine.contains_defeasible_production(
            example.metadata.production):
          (self._counters.example_attempts
           .wrong_qualifier_when_targeting_defeasible) += 1
          self._counters.example_attempts.valid -= 1
          continue

        example = attr.evolve(
            example, reply=cl.RuleReply.TRUE, qualifier=qualifier)
        yield example

    elif rule_reply == cl.RuleReply.FALSE:
      for example in rule_examples:
        # Turn the sampled positive rule into a distractor (negative) rule.
        try:
          distractor_creation_result = self._create_distractor_production(
              example.metadata.production)
        except distractor_generation.FailedToCreateDistractorError:
          self._counters.example_attempts.unable_to_create_negative_example += 1
          self._counters.example_attempts.valid -= 1
          continue

        distractor_production = distractor_creation_result.distractor
        # A FALSE reply requires the distractor to contradict the context.
        inconsistency = inference_engine.inconsistency_if_production_added(
            distractor_production)
        if inconsistency is None:
          (self._counters.example_attempts.distractor_consistent_with_context
          ) += 1
          self._counters.example_attempts.valid -= 1
          continue

        # The kind of contradiction must match the requested qualifier.
        if inconsistency.type != qualifier:
          if qualifier == cl.Qualifier.M:
            (self._counters.example_attempts
             .wrong_qualifier_when_targeting_monotonic) += 1
          else:
            (self._counters.example_attempts
             .wrong_qualifier_when_targeting_defeasible) += 1

          self._counters.example_attempts.valid -= 1
          continue

        request = rule_conversion.rule_from_production(
            distractor_production, self._options.rule_format)
        original_request = example.request
        new_source_production_by_source_production = {
            str(k): str(v) for k, v in distractor_creation_result
            .new_source_production_by_source_production.items()
        }

        # Here we do not update ExampleMetadata.rules since some strategies
        # of creating negative examples do not operate on the source rule
        # level.  This means that for negative rule examples, the rules field
        # in the metadata are rules for the original positive request.
        metadata = attr.evolve(
            example.metadata,
            applied_edits=distractor_creation_result.applied_edits,
            new_source_production_by_source_production=(
                new_source_production_by_source_production),
            original_request=original_request,
            production=distractor_production)
        example = attr.evolve(
            example,
            metadata=metadata,
            request=request,
            reply=rule_reply,
            qualifier=qualifier)

        yield example

  # This is used for context examples.
  def generate_n_derived_rule_examples(
      self,
      n,
      target_rule = None,
      rules_to_avoid_as_dependency = None,
      unreliable_rules = None):
    """Yields n randomly sampled derived-rule Examples.

    Each example's request is a derived rule built by composing the grammar's
    source rules, and its reply is TRUE. Sampling is with replacement, so
    duplicate examples are possible. Each example has an empty context; all
    other fields are filled in.

    Args:
      n: How many examples to yield.
      target_rule: Optional rule that every example must use, expressed in the
        rule format that ExampleMetadata tracks (one of enums.RuleFormats).
      rules_to_avoid_as_dependency: Only rules that do not depend on any of
        these rules are yielded.
      unreliable_rules: Optional rules that are swapped for freshly created
        distractor rules while each example is generated.
    """
    sampling_grammar = self._grammar_from_rules_to_avoid_as_dependency(
        rules_to_avoid_as_dependency)
    unreliable_rules = unreliable_rules or []
    unreliable_productions = [
        self._get_production_from_rule(rule) for rule in unreliable_rules
    ]

    if target_rule is None:
      target = None
    else:
      target = self._get_production_from_rule(target_rule)

    # Plain source rules are always filtered out via _not_source_rule(). With
    # a target rule, examples must also use it; the unreliable-production
    # check applies when the target is itself an unreliable rule.
    predicates = [self._not_source_rule()]
    if target_rule is not None:
      if target_rule in unreliable_rules:
        predicates.append(self._uses_target_unreliable_production(target))
      else:
        predicates.append(self._uses_target_production(target))

    sampling_parameters = _ProductionSamplingParameters(
        grammar=sampling_grammar,
        unreliable_productions=unreliable_productions,
        require_derived=True)

    for sampling_results in self._generate_n_productions_and_dependencies(
        n, predicates, sampling_parameters):
      derived_example = self._example_from_production_sampling_results(
          sampling_results,
          request_type=cl.RequestType.RULE,
          target_rule=target_rule or '')
      self._counters.example_attempts.valid += 1
      yield derived_example

  # This is used for context examples.
  def generate_n_non_rule_examples(
      self,
      n,
      target_rule = None,
      rules_to_avoid_as_dependency = None,
      unreliable_rules = None):
    """Yields n randomly sampled non-rule Examples.

    Each request is generated from the CFG core of the grammar and is
    guaranteed to be parseable by the full FeatureGrammar. Each reply is the
    semantics read from the special 'sem' feature of the resulting
    FeatureStruct. Sampling is with replacement, so duplicate examples are
    possible. Each example has an empty context; all other fields are filled
    in.

    Args:
      n: How many examples to yield.
      target_rule: Optional rule that every example must use, expressed in the
        rule format that ExampleMetadata tracks (one of enums.RuleFormats).
      rules_to_avoid_as_dependency: Only rules that do not depend on any of
        these rules are yielded.
      unreliable_rules: Optional rules that are swapped for freshly created
        distractor rules while each example is generated.
    """
    sampling_grammar = self._grammar_from_rules_to_avoid_as_dependency(
        rules_to_avoid_as_dependency)
    unreliable_rules = unreliable_rules or []
    unreliable_productions = [
        self._get_production_from_rule(rule) for rule in unreliable_rules
    ]

    target_is_unreliable = (
        target_rule is not None and target_rule in unreliable_rules)
    if target_rule is None:
      target = None
    else:
      target = self._get_production_from_rule(target_rule)

    predicates = [self._no_nonterminal()]
    if target_is_unreliable:
      # NOTE(review): unlike the other cases, no _valid_tokenized_sentence
      # predicate is added here -- confirm this is intentional.
      predicates.append(self._uses_target_unreliable_production(target))
    else:
      predicates.append(self._valid_tokenized_sentence(target))
      if target_rule is not None:
        predicates.append(self._uses_target_production(target))

    sampling_parameters = _ProductionSamplingParameters(
        grammar=sampling_grammar,
        unreliable_productions=unreliable_productions,
        yield_probability=0.0,
        start=self._grammar.start())

    for sampling_results in self._generate_n_productions_and_dependencies(
        n, predicates, sampling_parameters):
      non_rule_example = self._example_from_production_sampling_results(
          sampling_results,
          request_type=cl.RequestType.NON_RULE,
          target_rule=target_rule or '')

      # In natural-language mode the request is rewritten into its natural
      # language form.
      if self._options.rule_format == enums.RuleFormat.NATURAL_LANGUAGE:
        non_rule_example = attr.evolve(
            non_rule_example,
            request=rule_conversion.natural_language_non_rule_request(
                non_rule_example.request))

      self._counters.example_attempts.valid += 1
      yield non_rule_example

  # This is used for top-level examples.
  def generate_n_non_rule_examples_with_qualifier(
      self,
      n,
      unknown_reply,
      qualifier,
      inference_engine,
      context,
      unreliable_rules = None,
  ):
    """Yields up to n non-rule Examples with the requested reply/qualifier.

    Each request is generated from the CFG core of the grammar and is
    guaranteed to be parseable by the full FeatureGrammar. Each reply is the
    semantics read from the special 'sem' feature of the resulting
    FeatureStruct. Sampling is with replacement, so duplicate examples are
    possible. Each example has an empty context; all other fields are filled
    in.

    Candidates are filtered so that only examples with the requested
    attributes are yielded. Each rejected candidate increments a specific
    example_attempts counter, decrements `valid`, and is not yielded.

    Args:
      n: Number of examples to return.
      unknown_reply: Whether to yield examples with UNKNOWN reply.
      qualifier: Yield only examples of the provided qualifier.  Ignored if
        unknown_reply is True (since then the qualifier must be Defeasible).
      inference_engine: The InferenceEngine instance containing information
        about the context.
      context: The context generated along with inference_engine.
      unreliable_rules: If provided, these rules will be replaced with
        distractor rules while generating examples. Every generated example will
        use newly created distractor rules.
    """
    if unknown_reply or self._should_sample_from_grammar(inference_engine):
      candidates = self.generate_n_non_rule_examples(
          n, unreliable_rules=unreliable_rules)
    else:
      candidates = self._generate_n_examples_with_qualifier(
          n, inference_engine, context, qualifier, cl.RequestType.NON_RULE)

    if unknown_reply:
      for candidate in candidates:
        production = candidate.metadata.production
        if inference_engine.contains_production(production):
          (self._counters.example_attempts.wrong_reply_when_targeting_unknown
          ) += 1
          self._counters.example_attempts.valid -= 1
        elif inference_engine.inconsistency_if_production_added(
            production) is not None:
          self._counters.example_attempts.non_rule_example_inconsistent += 1
          self._counters.example_attempts.valid -= 1
        else:
          # The original reply is preserved in the metadata before it is
          # replaced by UNKNOWN.
          yield attr.evolve(
              candidate,
              reply=cl.RuleReply.UNKNOWN,
              qualifier=cl.Qualifier.D,
              metadata=attr.evolve(
                  candidate.metadata, original_reply=candidate.reply))
      return

    if qualifier == cl.Qualifier.M:
      for candidate in candidates:
        if inference_engine.contains_monotonic_production(
            candidate.metadata.production):
          yield attr.evolve(candidate, qualifier=cl.Qualifier.M)
        else:
          (self._counters.example_attempts
           .wrong_qualifier_when_targeting_monotonic) += 1
          self._counters.example_attempts.valid -= 1
      return

    if qualifier == cl.Qualifier.D:
      for candidate in candidates:
        if inference_engine.contains_defeasible_production(
            candidate.metadata.production):
          yield attr.evolve(candidate, qualifier=cl.Qualifier.D)
        else:
          (self._counters.example_attempts
           .wrong_qualifier_when_targeting_defeasible) += 1
          self._counters.example_attempts.valid -= 1
