# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import logging
from typing import Iterable

from absl.testing import absltest
from absl.testing import parameterized
import nltk
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import sampling
from conceptual_learning.cscan import test_utils

# Simple grammar for testing. Derived productions have derivation levels up
# to 2.
_GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2 = """
  % start D
  D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
  U[sem='WALK'] -> 'walk'
  W[sem='LTURN'] -> 'left'
  """


def _create_counters_with_non_zero_values():
  """Builds a GenerationCounters pre-populated with arbitrary non-zero counts.

  Starting from non-zero values lets tests verify that ExampleGenerator
  increments the counters rather than simply overwriting them.

  Returns:
    An outputs.GenerationCounters whose example-attempt counters and error
    counter are all initialized to non-zero values.
  """
  attempt_counters = outputs.ExampleAttemptCounters(
      ambiguous=1,
      max_derivation_level_reached=2,
      missing_target_rule=3,
      unparseable=4,
      valid=5,
  )
  error_counters = outputs.GenerationErrorCounters(
      failed_to_generate_derived_production=1)
  return outputs.GenerationCounters(
      example_attempts=attempt_counters, errors=error_counters)


def _log_examples(message, examples):
  """Logs `message` at INFO level followed by one numbered line per example.

  Args:
    message: Heading to print before the examples.
    examples: Iterable of objects; each is rendered with its enumeration index
      and its string form.
  """
  numbered_lines = [
      f'{index} {example}' for index, example in enumerate(examples)
  ]
  logging.info('\n%s: \n%s', message, '\n'.join(numbered_lines))


def _assert_production_provenance_populated_properly(
    test_case, examples):
  """Asserts several invariants about the examples' production provenances.

  Two invariants are checked, each in its own subtest of `test_case`:
    * every example has a non-empty production provenance;
    * for examples that do have one, the number of compositions in the
      provenance equals the example's derivation level.

  Args:
    test_case: Test case providing `subTest` and `assertEmpty`.
    examples: Iterable of hashable examples; it is traversed exactly once.
  """
  missing_provenance = set()
  mismatched_composition_count = set()
  for example in examples:
    metadata = example.metadata
    provenance = metadata.production_provenance
    if not provenance:
      missing_provenance.add(example)
      continue
    if len(provenance.compositions) != metadata.derivation_level:
      mismatched_composition_count.add(example)

  with test_case.subTest('provenance_always_populated'):
    test_case.assertEmpty(missing_provenance)
  with test_case.subTest('provenance_has_one_composition_per_derivation_level'):
    test_case.assertEmpty(mismatched_composition_count)


class GenerateNonRuleExampleTest(parameterized.TestCase):

  def assertSameExampleRepeatedNTimes(self, examples,
                                      expected_example_string, n):
    """Asserts that `examples` stringify to `expected_example_string` n times.

    Args:
      examples: Iterable of examples; each is converted with `str`.
      expected_example_string: The string every example is expected to equal.
      n: The expected number of examples.
    """
    expected_strings = (expected_example_string,) * n
    actual_strings = tuple(map(str, examples))
    self.assertEqual(expected_strings, actual_strings)

  def setUp(self):
    """Creates a fixed-seed RandomState shared by the tests in this class."""
    super().setUp()
    seed = 42
    self.rng = np.random.RandomState(seed)

  @parameterized.named_parameters(
      ('trivial_one_token_grammar', '<{}, left, LTURN, M>', """
          % start W
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar', '<{}, walk left, LTURN WALK, M>', """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
  )
  def test_generation_from_grammar_that_admits_only_one_possible_example(
      self, only_example_string_possible, grammar_string):
    """Should generate the requested number of examples, repeating as needed."""
    # Each of these grammars admits exactly one example, so the returned
    # examples are fully predictable even though generation is randomized.
    #
    # Sampling is done with replacement, so even a grammar admitting several
    # examples may return the same example more than once. This test is the
    # extreme case, in which we know for sure that the same example will be
    # repeated over and over.
    #
    # Both single-token and multi-token requests and replies need to be
    # covered, as the logic for converting these from feature structures to
    # strings can differ. (E.g., at one point there was a bug in which a single
    # token reply 'LTURN' would incorrectly become 'L T U R N'.)
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    example_generator = sampling.ExampleGenerator(grammar, rng=self.rng)
    num_examples = 5
    generated_examples = example_generator.generate_n_non_rule_examples(
        num_examples)
    self.assertSameExampleRepeatedNTimes(
        generated_examples, only_example_string_possible, num_examples)

  def test_should_not_generate_the_same_example_every_time_within_a_run(self):
    """A single batch of 5 examples should contain at least 2 distinct ones."""
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)
    examples = tuple(example_generator.generate_n_non_rule_examples(5))
    _log_examples(self.id(), examples)
    distinct_examples = set(examples)
    self.assertGreaterEqual(len(distinct_examples), 2)

  def test_should_not_generate_the_same_examples_every_time_across_runs(self):
    """Two consecutive batches from one generator should not be identical."""
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)
    # The batches are drawn one after the other from the same generator.
    first_batch = set(example_generator.generate_n_non_rule_examples(5))
    second_batch = set(example_generator.generate_n_non_rule_examples(5))
    self.assertNotEqual(first_batch, second_batch)

  @parameterized.named_parameters(('1', 1), ('10', 10))
  def test_should_generate_the_requested_number_of_examples(self, n):
    """Requesting n examples from the standard grammar should yield exactly n."""
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)
    generated_examples = tuple(
        example_generator.generate_n_non_rule_examples(n))
    self.assertLen(generated_examples, n)

  def test_should_increment_valid_attempt_counter_accurately(self):
    """The `valid` counter should grow by the number of examples returned."""
    counters = _create_counters_with_non_zero_values()
    valid_before = counters.example_attempts.valid
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        grammar=standard_grammar, counters=counters, rng=self.rng)
    examples = tuple(example_generator.generate_n_non_rule_examples(10))
    newly_counted_valid = counters.example_attempts.valid - valid_before
    self.assertLen(examples, newly_counted_valid)

  def test_should_use_target_rule_if_specified(self):
    """The generated example's metadata rules should include the target rule."""
    target_rule = '[x1 twice] = [x1] [x1]'
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)
    examples = tuple(
        example_generator.generate_n_non_rule_examples(1, target_rule))
    _log_examples(f'{self.id()}: {target_rule}', examples)
    first_example = examples[0]
    self.assertIn(target_rule, first_example.metadata.rules)

  def test_should_accept_target_rule_in_interpretation_rule_format(self):
    """A target rule given as an interpretation rule should be honored."""
    target_rule = '[x1 twice] = [x1] [x1]'
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, options=sampling_options, rng=self.rng)
    examples = tuple(
        example_generator.generate_n_non_rule_examples(1, target_rule))
    _log_examples(f'{self.id()}: {target_rule}', examples)
    first_example = examples[0]
    self.assertIn(target_rule, first_example.metadata.rules)

  def test_should_accept_target_rule_in_natural_language_format(self):
    """A target rule given in natural language should be honored."""
    target_rule = ('The interpretation of a phrase x1 followed by "twice" is '
                   'the interpretation of x1 repeated twice.')
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.NATURAL_LANGUAGE)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, options=sampling_options, rng=self.rng)
    examples = tuple(
        example_generator.generate_n_non_rule_examples(1, target_rule))
    _log_examples(f'{self.id()}: {target_rule}', examples)
    first_example = examples[0]
    self.assertIn(target_rule, first_example.metadata.rules)

  def test_should_honor_max_derivation_level_if_specified(self):
    """With max_derivation_level=0 only the level-0 example should appear."""
    grammar_string = """
        % start D
        D[sem='LTURN'] -> 'left'
        D[sem=?x1] -> U[sem=?x1]
        U[sem='WALK'] -> 'walk'
        """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    level_0_only = inputs.SamplingOptions(max_derivation_level=0)
    example_generator = sampling.ExampleGenerator(
        grammar=grammar, options=level_0_only, rng=self.rng)
    num_examples = 5
    only_level_0_example_string_possible = '<{}, left, LTURN, M>'
    generated_examples = example_generator.generate_n_non_rule_examples(
        num_examples)
    self.assertSameExampleRepeatedNTimes(
        generated_examples, only_level_0_example_string_possible,
        num_examples)

  @parameterized.named_parameters(
      # Here we force generation to fail by specifying a grammar from which the
      # only sentences that can be generated would have a parse tree deeper than
      # the specified maximum level.
      ('fails_due_to_generate_example_satisfying_max_derivation_level', """
          % start D
          D[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """, inputs.SamplingOptions(max_derivation_level=1), None),
      # Here we force generation to fail by specifying as target rule a rule
      # that does exist in the grammar but which is unreachable from the
      # grammar's start symbol.
      ('fails_due_to_generate_example_using_target_rule', """
          % start D
          D[sem='LTURN'] -> 'left'
          U[sem='WALK'] -> 'walk'
          """, inputs.SamplingOptions(max_derivation_level=1), '[walk] = WALK'),
  )
  def test_should_give_up_gracefully_when_max_attempts_per_example_is_reached(
      self, grammar_string, options, target_rule):
    """Generation should stop without examples once attempts are exhausted.

    Args:
      grammar_string: Grammar from which no acceptable example can be produced.
      options: Sampling options passed to the generator.
      target_rule: Optional target rule passed to the generation call.
    """
    counters = _create_counters_with_non_zero_values()
    total_before = counters.example_attempts.get_total()
    valid_before = counters.example_attempts.valid
    errors_before = counters.errors.failed_to_generate_derived_production

    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    example_generator = sampling.ExampleGenerator(
        grammar=grammar, options=options, counters=counters, rng=self.rng)
    generated_examples = example_generator.generate_n_non_rule_examples(
        5, target_rule=target_rule)
    examples = tuple(generated_examples)

    with self.subTest('should_perform_the_requested_number_of_attempts'):
      expected_total = total_before + options.max_attempts_per_example
      self.assertEqual(expected_total, counters.example_attempts.get_total())
    with self.subTest('none_of_the_attempts_should_succeed'):
      self.assertEqual(valid_before, counters.example_attempts.valid)
    with self.subTest('should_increment_error_counter'):
      expected_errors = errors_before + 1
      self.assertEqual(expected_errors,
                       counters.errors.failed_to_generate_derived_production)
    with self.subTest('should_return_no_examples'):
      self.assertEmpty(examples)

  def test_should_filter_out_ambiguous_sentences(self):
    """Ambiguous sentences should be counted as such and never returned."""
    # The CFG core of this grammar admits a single sentence ('walk'), but that
    # sentence can be parsed in two different ways, yielding a semantics of
    # either 'WALK' or 'JUMP'.
    grammar_string = """
        % start D
        D[sem=?x1] -> U[sem=?x1]
        U[sem='WALK'] -> 'walk'
        U[sem='JUMP'] -> 'walk'
        """

    options = inputs.SamplingOptions()
    counters = _create_counters_with_non_zero_values()
    ambiguous_before = counters.example_attempts.ambiguous
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    example_generator = sampling.ExampleGenerator(
        grammar=grammar, options=options, counters=counters, rng=self.rng)
    examples = tuple(example_generator.generate_n_non_rule_examples(5))

    with self.subTest('should_increment_the_relevant_attempt_counter'):
      expected_ambiguous = ambiguous_before + options.max_attempts_per_example
      self.assertEqual(expected_ambiguous, counters.example_attempts.ambiguous)
    with self.subTest('should_return_no_examples'):
      self.assertEmpty(examples)

  def test_should_filter_out_unparseable_sentences(self):
    """Unparseable sentences should be counted as such and never returned."""
    # The CFG core of this grammar admits a single sentence ('walk'), but once
    # the additional 'constraint' tag is taken into account, 'walk' is not
    # actually parseable under the full feature grammar.
    grammar_string = """
        % start D
        D[sem=?x1] -> U[sem=?x1, constraint='value1']
        U[sem='WALK', constraint='value2'] -> 'walk'
        """

    options = inputs.SamplingOptions()
    counters = _create_counters_with_non_zero_values()
    unparseable_before = counters.example_attempts.unparseable
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    example_generator = sampling.ExampleGenerator(
        grammar=grammar, options=options, counters=counters, rng=self.rng)
    examples = tuple(example_generator.generate_n_non_rule_examples(5))

    with self.subTest('should_increment_the_relevant_attempt_counter'):
      expected_unparseable = (
          unparseable_before + options.max_attempts_per_example)
      self.assertEqual(expected_unparseable,
                       counters.example_attempts.unparseable)
    with self.subTest('should_return_no_examples'):
      self.assertEmpty(examples)

  def test_should_raise_error_if_target_rule_not_in_grammar(self):
    """Consuming examples for an unknown target rule should raise ValueError."""
    target_rule = 'Some rule not in grammar'
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)
    # The generation call itself happens outside the assertion; only
    # materializing its result is expected to raise.
    examples = example_generator.generate_n_non_rule_examples(1, target_rule)
    with self.assertRaisesRegex(ValueError, 'Rule not found in grammar'):
      tuple(examples)

  @parameterized.named_parameters(
      ('trivial_one_token_grammar', set(['[left] = LTURN']), """
          % start W
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar',
       set([
           '[x1 x2] = [x2] [x1]',
           '[walk] = WALK',
           '[left] = LTURN',
       ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar',
       set(['[x1 and x2] = [x1] [x2]', '[walk] = WALK']), """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_populate_example_metadata_rules(
      self, possible_example_metadata_rules, grammar_string):
    """Metadata rules of non-rule examples should all come from the grammar.

    Args:
      possible_example_metadata_rules: Set of rule strings that are allowed to
        appear in the examples' metadata.
      grammar_string: Feature grammar to sample from.
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    observed_rules = set()
    for example in example_generator.generate_n_non_rule_examples(
        num_examples):
      observed_rules.update(example.metadata.rules)
    self.assertNotEmpty(observed_rules)
    self.assertLessEqual(observed_rules, possible_example_metadata_rules)

  def test_should_populate_example_metadata_derivation_levels(self):
    """Non-rule examples from the test grammar should all be at level 2."""
    allowed_levels = {2}
    grammar = nltk.grammar.FeatureGrammar.fromstring(
        _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    examples = example_generator.generate_n_non_rule_examples(num_examples)
    observed_levels = {example.metadata.derivation_level for example in examples}
    self.assertNotEmpty(observed_levels)
    self.assertLessEqual(observed_levels, allowed_levels)

  def test_should_populate_production_provenance(self):
    """Non-rule examples should satisfy the production provenance invariants."""
    grammar = nltk.grammar.FeatureGrammar.fromstring(
        _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    generated_examples = example_generator.generate_n_non_rule_examples(
        num_examples)
    _assert_production_provenance_populated_properly(self, generated_examples)

  @parameterized.named_parameters(
      ('four_token_grammar_avoid_walk', ['[walk] = WALK'],
       set([
           '<{}, walk right, RTURN WALK, M>', '<{}, walk left, LTURN WALK, M>'
       ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          U[sem='RUN'] -> 'run'
          W[sem='LTURN'] -> 'left'
          W[sem='RTURN'] -> 'right'
          """),
      ('four_token_grammar_avoid_left', ['[left] = LTURN'],
       set(['<{}, run left, LTURN RUN, M>', '<{}, walk left, LTURN WALK, M>'
           ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          U[sem='RUN'] -> 'run'
          W[sem='LTURN'] -> 'left'
          W[sem='RTURN'] -> 'right'
          """),
  )
  def test_should_avoid_dependencies_if_specified(self,
                                                  rules_to_avoid_as_dependency,
                                                  should_avoid_example_strings,
                                                  grammar_string):
    """Generated examples should not depend on any of the rules to avoid.

    Args:
      rules_to_avoid_as_dependency: Rule strings passed to the generator as
        dependencies to avoid.
      should_avoid_example_strings: Set of example strings that must not be
        generated when those rules are avoided.
      grammar_string: Feature grammar to sample from.
    """
    example_generator = sampling.ExampleGenerator(
        nltk.grammar.FeatureGrammar.fromstring(grammar_string),
        rng=self.rng,
        options=inputs.SamplingOptions(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE))
    n = 100
    examples = example_generator.generate_n_non_rule_examples(
        n, rules_to_avoid_as_dependency=rules_to_avoid_as_dependency)
    example_strings = {str(example) for example in examples}
    # Without this guard the test would pass vacuously if the generator yielded
    # no examples at all; the other subset-style tests in this file make the
    # same non-emptiness check.
    self.assertNotEmpty(example_strings)
    failed_to_avoid = example_strings.intersection(should_avoid_example_strings)

    self.assertEmpty(failed_to_avoid)

  @parameterized.named_parameters(('KNOWN_M', False, cl.Qualifier.M),
                                  ('KNOWN_D', False, cl.Qualifier.D),
                                  ('UNKNOWN_D', True, cl.Qualifier.D))
  def test_should_generate_examples_with_specified_reply_and_qualifier(
      self, unknown_reply, qualifier):
    """Non-rule examples should carry the requested qualifier and reply kind.

    Args:
      unknown_reply: Whether examples with an unknown reply are requested.
      qualifier: The qualifier every generated example is expected to have.
    """
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)

    # The fake inference engine has empty all_productions, so the example
    # generation method will fall back to sample productions from the grammar.
    inference_engine = test_utils.make_fake_inference_engine()
    context = cl.FrozenExampleSet()
    n = 20

    examples = list(
        example_generator.generate_n_non_rule_examples_with_qualifier(
            n=n,
            unknown_reply=unknown_reply,
            qualifier=qualifier,
            inference_engine=inference_engine,
            context=context))

    observed_replies = {example.reply for example in examples}
    observed_qualifiers = {example.qualifier for example in examples}

    # The reply is only pinned down when an unknown reply was requested.
    if unknown_reply:
      self.assertEqual(observed_replies, {cl.RuleReply.UNKNOWN})

    self.assertEqual(observed_qualifiers, {qualifier})


class GenerateDerivedProductionTest(parameterized.TestCase):

  def setUp(self):
    """Seeds a RandomState with a fixed value for the production tests."""
    super().setUp()
    fixed_seed = 42
    self.rng = np.random.RandomState(fixed_seed)

  @parameterized.named_parameters(
      ('trivial_one_token_grammar', set(["W[sem='LTURN'] -> 'left'"]), """
          % start W
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar',
       set([
           'D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]',
           "U[sem='WALK'] -> 'walk'", "W[sem='LTURN'] -> 'left'",
           "D[sem=(LTURN+?x1)] -> U[sem=?x1] 'left'",
           "D[sem=(?x1+WALK)] -> 'walk' W[sem=?x1]",
           "D[sem=(LTURN, WALK)] -> 'walk' 'left'"
       ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar',
       set([
           "V[sem='WALK'] -> 'walk'",
           "D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' U[sem=?x2]",
           "U[sem='WALK'] -> 'walk'",
           "D[sem=(WALK, WALK)] -> 'walk' 'and' 'walk'",
           "D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]",
           "D[sem=(?x1+WALK)] -> V[sem=?x1] 'and' 'walk'",
           "D[sem=(WALK+?x1)] -> 'walk' 'and' V[sem=?x1]",
           "D[sem=(?x1+WALK)] -> U[sem=?x1] 'and' 'walk'",
           "D[sem=(WALK+?x1)] -> 'walk' 'and' U[sem=?x1]",
           'V[sem=?x1] -> U[sem=?x1]',
           "D[sem=(?x1+?x2)] -> U[sem=?x1] 'and' V[sem=?x2]",
           "D[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]",
       ]), """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_sample_productions(self, possible_production_strings,
                                     grammar_string):
    """Sampled productions should all belong to the expected closed set.

    Args:
      possible_production_strings: Set of production strings that may be
        sampled from the grammar.
      grammar_string: Feature grammar to sample from.
    """
    # These grammars are small enough that the full set of possible derived
    # productions is not too large to enumerate.
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        max_attempts_per_example=10,
        rule_format=enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)

    def accept_all(results):
      """Filter that accepts every sampled result."""
      del results  # Unused.
      return True

    num_productions = 100
    sampled = example_generator._generate_n_productions_and_dependencies(
        num_productions, [accept_all])
    production_strings = [str(item.production) for item in sampled]

    self.assertNotEmpty(production_strings)
    unique_production_strings = set(production_strings)
    self.assertLessEqual(unique_production_strings,
                         possible_production_strings)

  def test_should_skip_productions_without_rule(self):
    """None of the listed pass-through productions should be sampled."""
    grammar_string = """
    % start D
    D[sem=?x1] -> V[sem=?x1]
    V[sem=?x1] -> U[sem=?x1]
    U[sem='WALK'] -> 'walk'
    """
    disallowed_production_strings = [
        'D[sem=?x1] -> U[sem=?x1]',
        'D[sem=?x1] -> V[sem=?x1]',
        'V[sem=?x1] -> U[sem=?x1]',
    ]
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        max_attempts_per_example=10,
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)

    def accept_all(results):
      """Filter that accepts every sampled result."""
      del results  # Unused.
      return True

    num_productions = 100
    sampled = example_generator._generate_n_productions_and_dependencies(
        num_productions, [accept_all])
    production_strings = [str(item.production) for item in sampled]
    self.assertNotEmpty(production_strings)
    wrongly_sampled = set(production_strings).intersection(
        disallowed_production_strings)
    self.assertEmpty(wrongly_sampled)

  def test_should_return_correct_dependency_productions(self):
    """The dependencies of the fully expanded production should be as listed."""
    grammar_string = """
    % start D
    D[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]
    U[sem='WALK'] -> 'walk'
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        max_attempts_per_example=10,
        rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        # A low yield probability is used so that the generated production
        # passes the only_target_production check below.
        derived_production_yield_probability=0.0)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)

    def only_target_production(results):
      """Accepts only the production in which both slots are filled in."""
      production_string = str(results.production)
      return production_string == "D[sem=(WALK, WALK)] -> 'walk' 'and' 'walk'"

    expected_dependency_production_strings = [
        "D[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]",
        "U[sem='WALK'] -> 'walk'",
        "U[sem='WALK'] -> 'walk'",
    ]

    all_results = list(
        example_generator._generate_n_productions_and_dependencies(
            1, [only_target_production]))
    result = all_results[0]
    dependency_production_strings = list(
        map(str, result.dependency_productions))
    self.assertEqual(dependency_production_strings,
                     expected_dependency_production_strings)

  def test_should_raise_max_attempts_reached_error_if_failed_to_generate(self):
    """Rejecting the only possible production should exhaust the attempts."""
    grammar_string = """
      % start W
      W[sem='LTURN'] -> 'left'
      """

    def avoids_only_possible_rule(results):
      """Rejects the single production this grammar contains."""
      production_string = str(results.production)
      return production_string != "W[sem='LTURN'] -> 'left'"

    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(max_attempts_per_example=10)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    parameters = sampling._ProductionSamplingParameters(
        grammar=example_generator._grammar)
    expected_message_regex = (
        'Failed to generate derived production due to reaching maximum number '
        'of attempts')
    with self.assertRaisesRegex(sampling.MaxAttemptsReachedError,
                                expected_message_regex):
      example_generator._generate_one_production_and_dependencies(
          [avoids_only_possible_rule], parameters)

  def test_should_raise_error_if_grammar_has_infinite_recursion(self):
    """Sampling from a self-recursive grammar should raise MaxLevelError."""
    # This exercises the inner-most production sampling function. In practice
    # the MaxLevelError would be captured by the caller and the dataset
    # generation process will continue.
    grammar_string = """
      % start D
      D[sem=?x1] -> D[sem=?x1]
      """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    example_generator = sampling.ExampleGenerator(grammar, rng=self.rng)
    parameters = sampling._ProductionSamplingParameters(
        grammar=example_generator._grammar, yield_probability=0.0)
    expected_message_regex = (
        'Failed to sample derived production due to '
        'reaching maximum level in random walk through grammar.')
    with self.assertRaisesRegex(sampling.MaxLevelError,
                                expected_message_regex):
      example_generator._sample_production_and_dependencies(parameters)


class GenerateRuleExampleTest(parameterized.TestCase):

  def setUp(self):
    """Seeds a RandomState with a fixed value for the rule example tests."""
    super().setUp()
    fixed_seed = 42
    self.rng = np.random.RandomState(fixed_seed)

  @parameterized.named_parameters(
      ('trivial_one_token_grammar', set(['<{}, [left] = LTURN, 1, M>']), """
          % start W
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar',
       set([
           '<{}, [x1 left] = LTURN [x1], 1, M>', '<{}, [left] = LTURN, 1, M>',
           '<{}, [x1 x2] = [x2] [x1], 1, M>',
           '<{}, [walk left] = LTURN WALK, 1, M>', '<{}, [walk] = WALK, 1, M>',
           '<{}, [walk x1] = [x1] WALK, 1, M>'
       ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar',
       set([
           '<{}, [x1 and walk] = [x1] WALK, 1, M>',
           '<{}, [walk and walk] = WALK WALK, 1, M>',
           '<{}, [walk and x1] = WALK [x1], 1, M>', '<{}, [walk] = WALK, 1, M>',
           '<{}, [x1 and x2] = [x1] [x2], 1, M>'
       ]), """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_sample_rule_examples(self, possible_example_strings,
                                       grammar_string):
    """Sampled rule examples should all belong to the expected closed set.

    Args:
      possible_example_strings: Set of example strings that may be generated.
      grammar_string: Feature grammar to sample from.
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    generated_examples = example_generator._generate_n_rule_examples(
        num_examples)
    example_strings = {str(example) for example in generated_examples}
    self.assertNotEmpty(example_strings)
    self.assertLessEqual(example_strings, possible_example_strings)

  @parameterized.named_parameters(
      ('trivial_one_token_grammar', set(['[left] = LTURN']), """
          % start W
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar',
       set([
           '[x1 x2] = [x2] [x1]',
           '[walk] = WALK',
           '[left] = LTURN',
       ]), """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar',
       set(['[x1 and x2] = [x1] [x2]', '[walk] = WALK']), """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_populate_example_metadata_rules(
      self, possible_example_metadata_rules, grammar_string):
    """Metadata rules of rule examples should all come from the grammar.

    Args:
      possible_example_metadata_rules: Set of rule strings that are allowed to
        appear in the examples' metadata.
      grammar_string: Feature grammar to sample from.
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100

    observed_rules = set()
    for example in example_generator._generate_n_rule_examples(num_examples):
      observed_rules.update(example.metadata.rules)
    self.assertNotEmpty(observed_rules)
    self.assertLessEqual(observed_rules, possible_example_metadata_rules)

  def test_should_populate_example_metadata_derivation_levels(self):
    """Rule examples from the test grammar should be at level 0, 1 or 2."""
    allowed_levels = {0, 1, 2}
    grammar = nltk.grammar.FeatureGrammar.fromstring(
        _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    examples = example_generator._generate_n_rule_examples(num_examples)
    observed_levels = {example.metadata.derivation_level for example in examples}
    self.assertNotEmpty(observed_levels)
    self.assertLessEqual(observed_levels, allowed_levels)

  def test_should_populate_production_provenance(self):
    """Rule examples should satisfy the production provenance invariants."""
    grammar = nltk.grammar.FeatureGrammar.fromstring(
        _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    generated_examples = example_generator._generate_n_rule_examples(
        num_examples)
    _assert_production_provenance_populated_properly(self, generated_examples)

  @parameterized.named_parameters(
      ('trivial_two_token_grammar',
       ['[x1 x2] = [x2] [x1]', '[walk] = WALK', '[left] = LTURN'], """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar', ['[x1 and x2] = [x1] [x2]', '[walk] = WALK'
                                   ], """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_avoid_specified_rules(self, rules_to_avoid, grammar_string):
    """None of the rules to avoid should appear as a generated request.

    Args:
      rules_to_avoid: Rule strings passed to the generator as rules to avoid.
      grammar_string: Feature grammar to sample from.
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        # A low yield probability is used here to make sure some examples are
        # generated that do not get filtered out.
        derived_production_yield_probability=0.1)
    example_generator = sampling.ExampleGenerator(
        grammar, rng=self.rng, options=sampling_options)
    num_examples = 100
    generated_examples = example_generator._generate_n_rule_examples(
        num_examples, rules_to_avoid=rules_to_avoid)
    sampled_rules = {example.request for example in generated_examples}
    self.assertNotEmpty(sampled_rules)
    wrongly_sampled = sampled_rules.intersection(rules_to_avoid)
    self.assertEmpty(wrongly_sampled)

  @parameterized.named_parameters(
      ('TRUE_M', cl.RuleReply.TRUE, cl.Qualifier.M),
      ('TRUE_D', cl.RuleReply.TRUE, cl.Qualifier.D),
      ('FALSE_M', cl.RuleReply.FALSE, cl.Qualifier.M),
      ('FALSE_D', cl.RuleReply.FALSE, cl.Qualifier.D),
      ('UNKNOWN_D', cl.RuleReply.UNKNOWN, cl.Qualifier.D))
  def test_should_generate_examples_with_specified_reply_and_qualifier(
      self, rule_reply, qualifier):
    """Rule examples should carry exactly the requested reply and qualifier.

    Args:
      rule_reply: The reply every generated example is expected to have.
      qualifier: The qualifier every generated example is expected to have.
    """
    standard_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    example_generator = sampling.ExampleGenerator(
        standard_grammar, rng=self.rng)

    # The fake inference engine has empty all_productions, so the example
    # generation method will fall back to sample productions from the grammar.
    inference_engine = test_utils.make_fake_inference_engine()
    context = cl.FrozenExampleSet()
    n = 20

    examples = list(
        example_generator.generate_n_rule_examples_with_reply_and_qualifier(
            n=n,
            rule_reply=rule_reply,
            qualifier=qualifier,
            inference_engine=inference_engine,
            context=context))

    observed_replies = {example.reply for example in examples}
    observed_qualifiers = {example.qualifier for example in examples}

    self.assertEqual(observed_replies, {rule_reply})
    self.assertEqual(observed_qualifiers, {qualifier})


class GenerateDerivedRuleExampleTest(parameterized.TestCase):
  """Tests for ExampleGenerator.generate_n_derived_rule_examples."""

  def setUp(self):
    super().setUp()
    # Fixed seed so that every run samples the same examples.
    self.rng = np.random.RandomState(42)

  def _make_example_generator(self, grammar_string, **sampling_options):
    """Builds an ExampleGenerator that uses the interpretation rule format.

    Args:
      grammar_string: Grammar definition that is parsed with
        nltk.grammar.FeatureGrammar.fromstring.
      **sampling_options: Additional keyword arguments passed through to
        inputs.SamplingOptions, e.g. derived_production_yield_probability.

    Returns:
      An ExampleGenerator that draws its randomness from self.rng.
    """
    return sampling.ExampleGenerator(
        nltk.grammar.FeatureGrammar.fromstring(grammar_string),
        rng=self.rng,
        options=inputs.SamplingOptions(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            **sampling_options))

  @parameterized.named_parameters(
      ('trivial_two_token_grammar', {
          '<{}, [walk x1] = [x1] WALK, 1, M>',
          '<{}, [x1 left] = LTURN [x1], 1, M>',
          '<{}, [walk left] = LTURN WALK, 1, M>',
      }, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar', {
          '<{}, [walk and walk] = WALK WALK, 1, M>',
          '<{}, [x1 and walk] = [x1] WALK, 1, M>',
          '<{}, [walk and x1] = WALK [x1], 1, M>',
      }, """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_sample_derived_rule_examples(self, possible_example_strings,
                                               grammar_string):
    """Checks that sampled examples are a non-empty subset of the expected."""
    example_generator = self._make_example_generator(
        grammar_string,
        # The generate_n_derived_rule_examples filters out examples that
        # have derivation level 0, so we use a low yield probability to make
        # sure some derived rules of positive derivation level are actually
        # generated.
        derived_production_yield_probability=0.1)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(n)
    example_strings = {str(example) for example in examples}
    self.assertNotEmpty(example_strings)
    self.assertLessEqual(example_strings, possible_example_strings)

  @parameterized.named_parameters(
      ('trivial_two_token_grammar_target_walk', '[walk] = WALK', {
          '<{}, [walk x1] = [x1] WALK, 1, M>',
          '<{}, [walk left] = LTURN WALK, 1, M>',
      }, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar_target_left', '[left] = LTURN', {
          '<{}, [x1 left] = LTURN [x1], 1, M>',
          '<{}, [walk left] = LTURN WALK, 1, M>',
      }, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
  )
  def test_should_use_target_rule_if_specified(self, target_rule,
                                               possible_example_strings,
                                               grammar_string):
    """Checks that, given a target rule, only the expected examples appear."""
    example_generator = self._make_example_generator(grammar_string)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(
        n, target_rule)
    example_strings = {str(example) for example in examples}
    self.assertNotEmpty(example_strings)
    self.assertLessEqual(example_strings, possible_example_strings)

  @parameterized.named_parameters(
      ('trivial_two_token_grammar_avoid_walk', ['[walk] = WALK'], {
          '<{}, [walk x1] = [x1] WALK, 1, M>',
          '<{}, [walk left] = LTURN WALK, 1, M>',
      }, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('trivial_two_token_grammar_avoid_left', ['[left] = LTURN'], {
          '<{}, [x1 left] = LTURN [x1], 1, M>',
          '<{}, [walk left] = LTURN WALK, 1, M>',
      }, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
  )
  def test_should_avoid_dependencies_if_specified(self,
                                                  rules_to_avoid_as_dependency,
                                                  should_avoid_example_strings,
                                                  grammar_string):
    """Checks that none of the examples that must be avoided is generated."""
    example_generator = self._make_example_generator(grammar_string)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(
        n, rules_to_avoid_as_dependency=rules_to_avoid_as_dependency)
    example_strings = {str(example) for example in examples}
    failed_to_avoid = example_strings.intersection(should_avoid_example_strings)

    # Without this guard the test would pass vacuously if the generator
    # produced no examples at all; the sibling tests guard the same way.
    self.assertNotEmpty(example_strings)
    self.assertEmpty(failed_to_avoid)

  @parameterized.named_parameters(
      ('trivial_two_token_grammar',
       {'[x1 x2] = [x2] [x1]', '[left] = LTURN', '[walk] = WALK'}, """
          % start D
          D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
          U[sem='WALK'] -> 'walk'
          W[sem='LTURN'] -> 'left'
          """),
      ('simple_two_layer_grammar',
       {'[walk] = WALK', '[x1 and x2] = [x1] [x2]'}, """
          % start D
          D[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]
          V[sem=?x1] -> U[sem=?x1]
          U[sem='WALK'] -> 'walk'
          """),
  )
  def test_should_populate_example_metadata_rules(
      self, possible_example_metadata_rules, grammar_string):
    """Checks that metadata.rules only ever contains the expected rules."""
    example_generator = self._make_example_generator(
        grammar_string, derived_production_yield_probability=0.1)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(n)
    # Pool the rules recorded in the metadata of all sampled examples.
    example_metadata_rules = set()
    for example in examples:
      example_metadata_rules.update(example.metadata.rules)
    self.assertNotEmpty(example_metadata_rules)
    self.assertLessEqual(example_metadata_rules,
                         possible_example_metadata_rules)

  def test_should_populate_example_metadata_derivation_levels(self):
    """Checks that metadata.derivation_level is always either 1 or 2."""
    grammar_string = _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2
    possible_metadata_derivation_levels = {1, 2}
    example_generator = self._make_example_generator(grammar_string)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(n)
    example_metadata_derivation_levels = {
        example.metadata.derivation_level for example in examples
    }
    self.assertNotEmpty(example_metadata_derivation_levels)
    self.assertLessEqual(example_metadata_derivation_levels,
                         possible_metadata_derivation_levels)

  def test_should_populate_production_provenance(self):
    """Runs the shared production provenance check on sampled examples."""
    grammar_string = _GRAMMAR_STRING_WITH_MAX_DERIVATION_LEVEL_2
    example_generator = self._make_example_generator(grammar_string)
    n = 100
    examples = example_generator.generate_n_derived_rule_examples(n)
    _assert_production_provenance_populated_properly(self, examples)


class GetExampleForExplicitRuleTest(parameterized.TestCase):
  """Tests for ExampleGenerator.get_example_for_explicit_rule."""

  def setUp(self):
    super().setUp()
    self.rng = np.random.RandomState(42)

  def test_should_create_correct_example_and_metadata_asserting_specified_rule(
      self):
    """Checks the example and metadata built for an explicitly given rule."""
    feature_grammar = nltk.grammar.FeatureGrammar.fromstring("""
    % start D
    D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
    U[sem='WALK'] -> 'walk'
    W[sem='LTURN'] -> 'left'
    """)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    generator = sampling.ExampleGenerator(
        feature_grammar, rng=self.rng, options=sampling_options)

    explicit_rule = '[x1 x2] = [x2] [x1]'
    actual = generator.get_example_for_explicit_rule(explicit_rule)

    # The expected production is the grammar's first production (presumably
    # the D rule, which is listed first in the grammar string).
    source_production = feature_grammar.productions()[0]
    expected_metadata = cl.ExampleMetadata(
        rules={explicit_rule},
        target_rule=explicit_rule,
        num_variables=2,
        production=source_production,
        production_provenance=production_composition.ProductionProvenance(
            source=source_production))
    expected = cl.Example(
        request=explicit_rule,
        reply=cl.RuleReply.TRUE,
        metadata=expected_metadata)

    with self.subTest('correct_example'):
      self.assertEqual(actual, expected)
    # The metadata is additionally compared through its string form.
    with self.subTest('correct_metadata'):
      self.assertEqual(str(actual.metadata), str(expected.metadata))


class ExampleFromSourceExamplesTest(parameterized.TestCase):
  """Tests for ExampleGenerator._example_from_source_examples."""

  def setUp(self):
    super().setUp()
    self.rng = np.random.RandomState(42)

  # In both cases the expected metadata combines the two source examples built
  # in the test body: the union of their rules, a derivation level of
  # 2 + 3 + 1, and both distractor entries.
  @parameterized.named_parameters(
      ('rule_example', cl.RequestType.RULE,
       cl.Example(
           request='[walk] = WALK',
           reply=cl.RuleReply.TRUE,
           metadata=cl.ExampleMetadata(
               rules={'rule1', 'rule2', 'rule3'},
               derivation_level=2 + 3 + 1,
               distractor_rules_by_unreliable_rule={
                   'unreliable_rule1': ['rule1'],
                   'unreliable_rule3': ['rule3'],
               }))),
      ('nonrule_example', cl.RequestType.NON_RULE,
       cl.Example(
           request='walk',
           reply='WALK',
           metadata=cl.ExampleMetadata(
               rules={'rule1', 'rule2', 'rule3'},
               derivation_level=2 + 3 + 1,
               distractor_rules_by_unreliable_rule={
                   'unreliable_rule1': ['rule1'],
                   'unreliable_rule3': ['rule3'],
               }))),
  )
  def test_should_create_example_from_source_examples(self, request_type,
                                                      expected_example):
    """Checks the example assembled from a production and source examples.

    Args:
      request_type: The cl.RequestType of the example to create.
      expected_example: The cl.Example the generator is expected to return.
    """
    # The grammar is only a placeholder needed to construct the generator.
    placeholder_grammar = nltk.grammar.FeatureGrammar.fromstring("""
    % start X
    X -> X
    """)
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator = sampling.ExampleGenerator(
        placeholder_grammar, rng=self.rng, options=sampling_options)

    first_source = cl.Example(
        request='q1',
        reply='r1',
        metadata=cl.ExampleMetadata(
            rules={'rule1', 'rule2'},
            derivation_level=2,
            distractor_rules_by_unreliable_rule={
                'unreliable_rule1': ['rule1']
            }))
    second_source = cl.Example(
        request='q2',
        reply='r2',
        metadata=cl.ExampleMetadata(
            rules={'rule2', 'rule3'},
            derivation_level=3,
            distractor_rules_by_unreliable_rule={
                'unreliable_rule3': ['rule3']
            }))
    source_examples = [first_source, second_source]

    production = nltk_utils.production_from_production_string(
        "U[sem='WALK'] -> 'walk'")

    # Register a trivial provenance for the production so that the call below
    # does not error out. (A production really composed from several rules
    # would presumably have a richer provenance, but it would also be a more
    # complex production in the first place.)
    trivial_provenance = production_composition.ProductionProvenance(
        source=production)
    example_generator.provenance_by_production.setdefault(
        production, trivial_provenance)

    example = example_generator._example_from_source_examples(
        production, source_examples, request_type)

    self.assertEqual(
        example, expected_example, msg=f'{example}, {example.metadata}')


# Run the tests in this module through absltest when executed as a script.
if __name__ == '__main__':
  absltest.main()
