# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
from typing import Any, Iterable, Mapping, Set

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition


def _get_production_set_from_index(
    index):
  """Returns the productions occurring anywhere in the index values."""
  return set(itertools.chain.from_iterable(index.values()))


def _get_production_set_from_nested_index(
    index
):
  """Returns the productions occurring anywhere in the nested index values."""
  return set(
      itertools.chain.from_iterable(
          itertools.chain.from_iterable(sub_index.values())
          for sub_index in index.values()))


class InferenceUtilityFunctionsTest(parameterized.TestCase):
  """Tests for the private token-extraction helper in `inference`."""

  @parameterized.named_parameters(
      dict(
          testcase_name='no_variable',
          production_string="A[sem=(WALK+WALK)] -> 'walk' 'twice'",
          expected_input_tokens=('walk', 'twice'),
          expected_output_tokens=('WALK', 'WALK')),
      dict(
          testcase_name='variable_on_lhs',
          production_string="A[sem=(?x1+?x1)] -> 'walk' 'twice'",
          expected_input_tokens=('walk', 'twice'),
          expected_output_tokens=('?x1', '?x1')),
      dict(
          testcase_name='variable_on_rhs',
          production_string="A[sem=(WALK+WALK)] -> B[sem=?x1] 'twice'",
          expected_input_tokens=('?x1', 'twice'),
          expected_output_tokens=('WALK', 'WALK')),
      dict(
          testcase_name='variables_on_both_sides',
          production_string="A[sem=(?x1+?x1)] -> B[sem=?x1] 'twice'",
          expected_input_tokens=('?x1', 'twice'),
          expected_output_tokens=('?x1', '?x1')))
  def test_extract_input_tokens_and_output_tokens(self, production_string,
                                                  expected_input_tokens,
                                                  expected_output_tokens):
    """Checks the (input, output) token tuples derived from one production.

    Rhs nonterminals surface as their semantic variable (e.g. '?x1') among
    the input tokens; the lhs `sem` expression supplies the output tokens.
    """
    parsed = nltk_utils.production_from_production_string(production_string)
    extracted = inference._extract_input_tokens_and_output_tokens(parsed)
    actual_input_tokens, actual_output_tokens = extracted

    self.assertEqual(actual_input_tokens, expected_input_tokens)
    self.assertEqual(actual_output_tokens, expected_output_tokens)


class BackupStatesTest(absltest.TestCase):
  """Tests for InferenceEngine.copy_monotonic_engine and backup_states."""

  @staticmethod
  def _find_equal_production(productions, target):
    """Returns the element of `productions` that equals `target`, or None.

    Callers compare the result to `target` by identity to check that a copy
    shares production objects instead of duplicating them. Returning None
    (rather than indexing into a possibly empty list) lets that identity
    subtest fail with a readable message. Otherwise an IndexError would abort
    every other subtest of the test method.
    """
    return next(
        (production for production in productions if production == target),
        None)

  def _assert_inference_engine_collections_are_not_identical(
      self, engine1, engine2):
    """Asserts the two engines do not share any of these collection objects.

    Content equality is checked separately by the callers. This only checks
    that the copy did not alias the original's sets and indices.
    """
    self.assertIsNot(engine1.source_productions, engine2.source_productions)
    self.assertIsNot(engine1.monotonic_productions,
                     engine2.monotonic_productions)
    self.assertIsNot(engine1.all_productions, engine2.all_productions)
    self.assertIsNot(engine1._productions_by_input_tokens,
                     engine2._productions_by_input_tokens)
    self.assertIsNot(engine1._productions_by_lhs_symbol,
                     engine2._productions_by_lhs_symbol)
    self.assertIsNot(engine1._productions_by_index_by_rhs_symbol,
                     engine2._productions_by_index_by_rhs_symbol)
    self.assertIsNot(engine1._source_productions_by_production,
                     engine2._source_productions_by_production)
    self.assertIsNot(engine1._productions_by_num_variables,
                     engine2._productions_by_num_variables)

  def test_copy_monotonic_engine(self):
    # Three source productions that we will add to the inference engine.
    # From these, we expect two additional composed productions to be inferred:
    # one monotonic ('jump twice') and one defeasible ('walk twice').
    production_m_1 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_m_2 = nltk_utils.production_from_production_string(
        "V[sem=JUMP] -> 'jump'")
    production_d = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_m_1, is_monotonic=True)
    inference_engine.add_production(production_m_2, is_monotonic=True)
    inference_engine.add_production(production_d)

    copied = inference_engine.copy_monotonic_engine()
    # Look up the copy's versions of both monotonic source productions so
    # that object identity can be checked for each of them below.
    copied_production_m_1 = self._find_equal_production(
        copied.source_productions, production_m_1)
    copied_production_m_2 = self._find_equal_production(
        copied.source_productions, production_m_2)

    with self.subTest('should_create_new_inference_engine_instance'):
      self.assertIsNot(inference_engine, copied)

    with self.subTest('should_create_new_collections'):
      self._assert_inference_engine_collections_are_not_identical(
          inference_engine, copied)

    with self.subTest('monotonic_productions_should_have_equal_content'):
      self.assertEqual(inference_engine.monotonic_productions,
                       copied.monotonic_productions)

    with self.subTest('defeasible_productions_should_be_empty'):
      self.assertEmpty(copied.defeasible_productions)
      self.assertEqual(copied.source_productions,
                       {production_m_1, production_m_2})
      self.assertEqual(copied.all_productions, copied.monotonic_productions)

    with self.subTest(
        'productions_by_input_tokens_should_be_populated_correctly'):
      self.assertEqual(
          copied.all_productions,
          _get_production_set_from_index(copied._productions_by_input_tokens))

    with self.subTest(
        'productions_by_lhs_symbol_should_be_populated_correctly'):
      self.assertEqual(
          copied.all_productions,
          _get_production_set_from_index(copied._productions_by_lhs_symbol))

    with self.subTest(
        'productions_by_index_by_rhs_symbol_should_be_populated_correctly'):
      self.assertEqual(
          # Only the monotonic productions that have variables on the rhs.
          {production_m_1},
          _get_production_set_from_nested_index(
              copied._productions_by_index_by_rhs_symbol),
          f'copied._productions_by_index_by_rhs_symbol = '
          f'{copied._productions_by_index_by_rhs_symbol}')

    with self.subTest(
        'source_productions_by_production_should_be_populated_correctly'):
      self.assertEqual(copied.all_productions,
                       set(copied._source_productions_by_production.keys()))
      self.assertEqual(
          copied.source_productions,
          _get_production_set_from_index(
              copied._source_productions_by_production))

    with self.subTest(
        'productions_by_num_variables_should_be_populated_correctly'):
      self.assertEqual(
          copied.all_productions,
          _get_production_set_from_index(copied._productions_by_num_variables))

    with self.subTest('provenance_by_production_should_be_populated_correctly'):
      self.assertEqual(copied.all_productions,
                       set(copied.provenance_by_production.keys()))

    with self.subTest('should_not_duplicate_productions'):
      # Every monotonic source production must be shared, not re-created.
      self.assertIs(production_m_1, copied_production_m_1)
      self.assertIs(production_m_2, copied_production_m_2)

  def test_backup_states(self):
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0, is_monotonic=True)
    inference_engine.add_production(production_1)

    backup = inference_engine.backup_states()
    backup_production_0 = self._find_equal_production(
        backup.source_productions, production_0)
    backup_production_1 = self._find_equal_production(
        backup.source_productions, production_1)

    with self.subTest('should_create_new_inference_engine_instance'):
      self.assertIsNot(inference_engine, backup)

    with self.subTest('should_create_new_collections'):
      self._assert_inference_engine_collections_are_not_identical(
          inference_engine, backup)

    with self.subTest('production_sets_should_have_equal_content'):
      self.assertEqual(inference_engine.source_productions,
                       backup.source_productions)
      self.assertEqual(inference_engine.monotonic_productions,
                       backup.monotonic_productions)
      self.assertEqual(inference_engine.all_productions, backup.all_productions)

    with self.subTest('production_indices_should_have_equal_content'):
      self.assertEqual(inference_engine._productions_by_input_tokens,
                       backup._productions_by_input_tokens)
      self.assertEqual(inference_engine._productions_by_lhs_symbol,
                       backup._productions_by_lhs_symbol)
      self.assertEqual(inference_engine._productions_by_index_by_rhs_symbol,
                       backup._productions_by_index_by_rhs_symbol)
      self.assertEqual(inference_engine._source_productions_by_production,
                       backup._source_productions_by_production)
      self.assertEqual(inference_engine._productions_by_num_variables,
                       backup._productions_by_num_variables)

    with self.subTest('should_not_duplicate_productions'):
      self.assertIs(production_0, backup_production_0)
      self.assertIs(production_1, backup_production_1)


class AddMonotonicProductionTest(parameterized.TestCase):
  """Tests for InferenceEngine.add_production with is_monotonic=True."""

  def _assert_monotonic_addition_rejected(self, inference_engine,
                                          contradicting_production):
    """Asserts a monotonic addition raises and leaves the engine untouched.

    A deep copy is taken before the attempt, so the engine's state after the
    failed addition can be compared with its state before it.
    """
    snapshot = copy.deepcopy(inference_engine)

    with self.subTest('should_raise_error'):
      with self.assertRaisesRegex(
          inference.InconsistencyError,
          'Adding production would cause inconsistency'):
        inference_engine.add_production(
            contradicting_production, is_monotonic=True)

    with self.subTest('should_not_alter_states'):
      self.assertEqual(inference_engine, snapshot)

  def test_basic_behavior(self):
    """A lone monotonic production is stored and indexed by input tokens."""
    production = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    engine = inference.InferenceEngine()
    engine.add_production(production, is_monotonic=True)

    tokens = ('?x1', 'twice')

    with self.subTest('should_include_production_in_all_productions'):
      self.assertIn(production, engine.all_productions)

    with self.subTest('should_include_production_in_monotonic_productions'):
      self.assertIn(production, engine.monotonic_productions)

    with self.subTest('should_record_production_by_input_tokens'):
      by_tokens = engine._productions_by_input_tokens
      self.assertIn(tokens, by_tokens)
      self.assertSetEqual(by_tokens[tokens], {production})

  def test_should_accept_new_production_without_inconsistency(self):
    """Two consistent monotonic productions yield a monotonic composition."""
    parse = nltk_utils.production_from_production_string
    production_0 = parse("S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_1 = parse("V[sem=WALK] -> 'walk'")

    engine = inference.InferenceEngine()
    for production in (production_0, production_1):
      engine.add_production(production, is_monotonic=True)

    composed = parse("S[sem=(WALK+WALK)] -> 'walk' 'twice'")
    composed_tokens = ('walk', 'twice')

    with self.subTest('should_include_productions_in_monotonic_productions'):
      self.assertContainsSubset([production_0, production_1],
                                engine.monotonic_productions)

    with self.subTest(
        'should_include_composed_production_in_monotonic_productions'):
      self.assertIn(composed, engine.monotonic_productions)

    with self.subTest('should_record_composed_production_by_input_tokens'):
      by_tokens = engine._productions_by_input_tokens
      self.assertIn(composed_tokens, by_tokens)
      self.assertSetEqual(by_tokens[composed_tokens], {composed})

  def test_should_reject_production_contradicting_monotonic_production(self):
    """The same input mapped to a different meaning is rejected."""
    parse = nltk_utils.production_from_production_string
    existing = parse("V[sem=RUN] -> 'walk'")
    contradicting = parse("V[sem=WALK] -> 'walk'")

    engine = inference.InferenceEngine()
    engine.add_production(existing, is_monotonic=True)

    self._assert_monotonic_addition_rejected(engine, contradicting)

  @parameterized.named_parameters(
      dict(
          testcase_name='rule_with_variable_monotonic',
          production_string_0="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
          production_string_1="V[sem=WALK] -> 'walk'",
          contradicting_production_string="S[sem=WALK] -> 'walk' 'twice'"),
      dict(
          testcase_name='rule_with_variable_contradicting',
          production_string_0="V[sem=JUMP] -> 'jump'",
          production_string_1="S[sem=(JUMP+JUMP+JUMP)] -> 'jump' 'twice'",
          contradicting_production_string=(
              "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")))
  def test_should_reject_production_contradicting_derived_production(
      self, production_string_0, production_string_1,
      contradicting_production_string):
    """A production conflicting with a composition-derived one is rejected."""
    parse = nltk_utils.production_from_production_string
    engine = inference.InferenceEngine()
    for production_string in (production_string_0, production_string_1):
      engine.add_production(parse(production_string), is_monotonic=True)

    self._assert_monotonic_addition_rejected(
        engine, parse(contradicting_production_string))


class AddDefeasibleProductionTest(absltest.TestCase):
  """Tests for adding productions as defeasible via add_production."""

  def _assert_defeasible_addition_rejected(self, inference_engine,
                                           contradicting_production):
    """Asserts a defeasible addition raises and leaves the engine untouched.

    The engine is deep-copied before the attempt, so its state after the
    failed addition can be compared with its state before it.
    """
    original_inference_engine = copy.deepcopy(inference_engine)

    with self.subTest('should_raise_error'):
      with self.assertRaisesRegex(
          inference.InconsistencyError,
          'Adding production would cause inconsistency'):
        inference_engine.add_production(contradicting_production)

    with self.subTest('should_not_alter_states'):
      self.assertEqual(inference_engine, original_inference_engine)

  def _assert_readding_does_not_alter_states(self, inference_engine,
                                             production):
    """Asserts that re-adding `production` as defeasible is a no-op."""
    original_inference_engine = copy.deepcopy(inference_engine)

    inference_engine.add_production(production)

    with self.subTest('should_not_alter_states'):
      # Print both engines on failure so the differing state is visible.
      self.assertEqual(
          inference_engine,
          original_inference_engine,
          msg=f'>>> {inference_engine}\n\n>>>{original_inference_engine}')

  def test_basic_behavior(self):
    production_string = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production = nltk_utils.production_from_production_string(production_string)
    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production)

    expected_input_tokens = ('?x1', 'twice')

    with self.subTest('should_include_production_in_all_productions'):
      self.assertIn(production, inference_engine.all_productions)

    with self.subTest('should_not_include_production_in_monotonic_productions'):
      self.assertNotIn(production, inference_engine.monotonic_productions)

    with self.subTest('should_include_production_in_defeasible_productions'):
      self.assertIn(production, inference_engine.defeasible_productions)

    with self.subTest('should_record_production_by_input_tokens'):
      self.assertIn(expected_input_tokens,
                    inference_engine._productions_by_input_tokens)
      self.assertSetEqual(
          inference_engine._productions_by_input_tokens[expected_input_tokens],
          {production})

  def test_should_accept_new_production_without_inconsistency_with_defeasible(
      self):
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0)
    inference_engine.add_production(production_1)

    composed_production = nltk_utils.production_from_production_string(
        "S[sem=(WALK+WALK)] -> 'walk' 'twice'")

    expected_input_tokens = ('walk', 'twice')

    with self.subTest('should_include_productions_in_defeasible_productions'):
      self.assertContainsSubset([production_0, production_1],
                                inference_engine.defeasible_productions)

    with self.subTest(
        'should_include_composed_production_in_defeasible_productions'):
      self.assertIn(composed_production,
                    inference_engine.defeasible_productions)

    with self.subTest('should_record_composed_production_by_input_tokens'):
      self.assertIn(expected_input_tokens,
                    inference_engine._productions_by_input_tokens)
      self.assertSetEqual(
          inference_engine._productions_by_input_tokens[expected_input_tokens],
          {composed_production})

  def test_should_accept_new_production_without_inconsistency_with_monotonic(
      self):
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0, is_monotonic=True)
    inference_engine.add_production(production_1)

    composed_production = nltk_utils.production_from_production_string(
        "S[sem=(WALK+WALK)] -> 'walk' 'twice'")

    expected_input_tokens = ('walk', 'twice')

    with self.subTest('should_include_productions_in_correct_collections'):
      self.assertIn(production_0, inference_engine.monotonic_productions)
      self.assertIn(production_1, inference_engine.defeasible_productions)

    with self.subTest(
        'should_include_composed_production_in_defeasible_productions'):
      # One defeasible parent is enough to make the composition defeasible.
      self.assertIn(composed_production,
                    inference_engine.defeasible_productions)

    with self.subTest('should_record_composed_production_by_input_tokens'):
      self.assertIn(expected_input_tokens,
                    inference_engine._productions_by_input_tokens)
      self.assertSetEqual(
          inference_engine._productions_by_input_tokens[expected_input_tokens],
          {composed_production})

  def test_should_reject_production_contradicting_monotonic_production(self):
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=RUN] -> 'walk'")
    contradicting_production = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0, is_monotonic=True)

    self._assert_defeasible_addition_rejected(inference_engine,
                                              contradicting_production)

  def test_should_reject_production_contradicting_defeasible_production(self):
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=RUN] -> 'walk'")
    contradicting_production = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0)

    self._assert_defeasible_addition_rejected(inference_engine,
                                              contradicting_production)

  def test_should_reject_production_contradicting_derived_production(self):
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")

    # Contradicts the derived S[sem=(WALK+WALK)] -> 'walk' 'twice'.
    contradicting_production = nltk_utils.production_from_production_string(
        "S[sem=WALK] -> 'walk' 'twice'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0, is_monotonic=True)
    inference_engine.add_production(production_1, is_monotonic=True)

    self._assert_defeasible_addition_rejected(inference_engine,
                                              contradicting_production)

  def test_should_ignore_existing_monotonic_production(self):
    production = nltk_utils.production_from_production_string(
        "V[sem=RUN] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production, is_monotonic=True)

    self._assert_readding_does_not_alter_states(inference_engine, production)

  def test_should_ignore_existing_defeasible_production(self):
    production = nltk_utils.production_from_production_string(
        "V[sem=RUN] -> 'walk'")

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production)

    self._assert_readding_does_not_alter_states(inference_engine, production)

  def test_should_promote_defeasible_production_to_monotonic(self):
    # Here the productions are chosen in a way so that not only the composed
    # production would get promoted to monotonic production when a previously
    # defeasible parent is promoted to monotonic, but a production that can be
    # obtained only by further composing a composed production also needs to be
    # promoted.
    # This illustrates the need for considering all possible further
    # compositions when a composed production is promoted, in the implementation
    # of InferenceEngine._add_production.
    production_string_0 = "V[sem=WALK] -> 'walk'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_string_1 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_string_2 = "C[sem=(?x1+STOP)] -> S[sem=?x1] 'stop'"
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0, is_monotonic=True)
    inference_engine.add_production(production_1, is_monotonic=False)
    inference_engine.add_production(production_2, is_monotonic=True)

    composed_production_string_01 = "S[sem=(WALK+WALK)] -> 'walk' 'twice'"
    composed_production_01 = nltk_utils.production_from_production_string(
        composed_production_string_01)
    composed_production_string_12 = (
        "C[sem=(?x1+?x1+STOP)] -> V[sem=?x1] 'twice' 'stop'")
    composed_production_12 = nltk_utils.production_from_production_string(
        composed_production_string_12)

    # This production can be derived only by composing one of the composed
    # productions above with one of the originally added monotonic productions.
    further_composed_production_string = (
        "C[sem=(WALK+WALK+STOP)] -> 'walk' 'twice' 'stop'")
    further_composed_production = nltk_utils.production_from_production_string(
        further_composed_production_string)

    with self.subTest(
        'should_include_composed_production_in_defeasible_productions'):
      self.assertIn(composed_production_01,
                    inference_engine.defeasible_productions)
      self.assertIn(composed_production_12,
                    inference_engine.defeasible_productions)
      self.assertIn(further_composed_production,
                    inference_engine.defeasible_productions)

    with self.subTest('should_have_correct_monotonic_productions'):
      self.assertCountEqual(inference_engine.monotonic_productions,
                            [production_0, production_2])

    # production_1 being defeasible is the reason for all the other defeasible
    # productions.  Promoting it to monotonic should promote all the composed
    # productions to monotonic.
    inference_engine.add_production(production_1, is_monotonic=True)

    with self.subTest('should_move_productions_to_monotonic_productions'):
      # A move means the production is monotonic AND no longer defeasible.
      self.assertIn(production_1, inference_engine.monotonic_productions)
      self.assertNotIn(production_1, inference_engine.defeasible_productions)

    with self.subTest('should_promote_composed_production_to_monotonic'):
      self.assertIn(composed_production_01,
                    inference_engine.monotonic_productions)
      self.assertIn(composed_production_12,
                    inference_engine.monotonic_productions)

    with self.subTest(
        'should_promote_further_composed_production_to_monotonic'):
      self.assertIn(further_composed_production,
                    inference_engine.monotonic_productions)


class ForceAddProductionTest(parameterized.TestCase):
  """Tests for `InferenceEngine.force_add_production`.

  Covers how a force-added production is indexed, the inconsistency and
  implication reported in the returned status, provenance tracking with
  `track_multiple_provenances=True`, and composition when one input token has
  several meanings.
  """

  def test_basic_behavior(self):
    """A force-added production appears in every relevant collection."""
    production_string = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production = nltk_utils.production_from_production_string(production_string)
    inference_engine = inference.InferenceEngine(
        track_multiple_provenances=True)
    inference_engine.force_add_production(production, is_monotonic=True)

    # The RHS nonterminal V[sem=?x1] is keyed by its semantic variable.
    expected_input_tokens = ('?x1', 'twice')

    with self.subTest('should_include_production_in_all_productions'):
      self.assertIn(production, inference_engine.all_productions)

    with self.subTest('should_include_production_in_monotonic_productions'):
      self.assertIn(production, inference_engine.monotonic_productions)

    with self.subTest('should_record_production_by_input_tokens'):
      self.assertIn(expected_input_tokens,
                    inference_engine._productions_by_input_tokens)
      self.assertSetEqual(
          inference_engine._productions_by_input_tokens[expected_input_tokens],
          {production})

  def test_should_detect_inconsistency(self):
    """Force-adding a production that contradicts a composition is reported."""
    production_string_0 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_string_1 = "V[sem='JUMP'] -> 'jump'"
    production_string_2 = "S[sem=('JUMP'+'JUMP'+'JUMP')] -> 'jump' 'twice'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)
    inference_engine = inference.InferenceEngine(
        track_multiple_provenances=True)
    inference_engine.force_add_production(production_0)
    inference_engine.force_add_production(production_1)
    status = inference_engine.force_add_production(production_2)
    inconsistency = status.inconsistency
    # Fail once with a clear message instead of an AttributeError in each of
    # the subtests below that dereference `inconsistency`.
    self.assertIsNotNone(inconsistency)
    with self.subTest(
        'should_return_correct_directly_inconsistent_productions'):
      expected_existing_inconsistency_string = (
          "S[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'")
      expected_existing_inconsistency = (
          nltk_utils.production_from_production_string(
              expected_existing_inconsistency_string))
      expected_incoming_inconsistency = production_2
      self.assertEqual(expected_existing_inconsistency,
                       inconsistency.existing_inconsistency)
      self.assertEqual(expected_incoming_inconsistency,
                       inconsistency.incoming_inconsistency)

    with self.subTest('should_return_correct_inconsistency_source'):
      self.assertSetEqual({production_0, production_1},
                          inconsistency.existing_inconsistency_source)
      self.assertSetEqual({production_2},
                          inconsistency.incoming_inconsistency_source)
    with self.subTest('should_not_detect_implication'):
      self.assertIsNone(status.implication)

  @parameterized.named_parameters(('monotonic_case', True),
                                  ('defeasible_case', False))
  def test_should_detect_implication(self, is_monotonic):
    """Force-adding an already-derivable production reports an implication.

    Args:
      is_monotonic: Qualifier used for every force-added production; the
        reported implication type is expected to match it.
    """
    production_string_0 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_string_1 = "V[sem='JUMP'] -> 'jump'"
    production_string_2 = "S[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)
    inference_engine = inference.InferenceEngine()
    inference_engine.force_add_production(
        production_0, is_monotonic=is_monotonic)
    inference_engine.force_add_production(
        production_1, is_monotonic=is_monotonic)
    status = inference_engine.force_add_production(
        production_2, is_monotonic=is_monotonic)
    self.assertIsNotNone(status.implication)
    implication = status.implication
    with self.subTest('should_record_implied_production'):
      expected_production = production_2
      self.assertEqual(expected_production, implication.production)
    with self.subTest('should_return_implication_source'):
      self.assertSetEqual({production_0, production_1},
                          implication.source_productions)
    with self.subTest('should_detect_the_correct_type'):
      expected_type = cl.Qualifier.M if is_monotonic else cl.Qualifier.D
      self.assertEqual(expected_type, implication.type)
    with self.subTest('should_not_detect_inconsistency'):
      self.assertIsNone(status.inconsistency)

  def test_should_detect_multiple_inconsistencies(self):
    """Inconsistent productions can coexist and keep being reported."""
    production_string_0 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_string_1 = "V[sem='JUMP'] -> 'jump'"
    production_string_2 = "V[sem='RUN'] -> 'run'"
    production_string_3 = "S[sem=('RUN'+'RUN'+'RUN')] -> 'run' 'twice'"
    production_string_4 = "S[sem=('JUMP'+'JUMP'+'JUMP')] -> 'jump' 'twice'"
    production_string_5 = "S[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)
    production_3 = nltk_utils.production_from_production_string(
        production_string_3)
    production_4 = nltk_utils.production_from_production_string(
        production_string_4)
    production_5 = nltk_utils.production_from_production_string(
        production_string_5)
    inference_engine = inference.InferenceEngine()
    inference_engine.force_add_production(production_0)
    inference_engine.force_add_production(production_1)
    inference_engine.force_add_production(production_2)
    # production_3 gives 'run' 'twice' three RUNs, whereas composing
    # production_2 into production_0 gives two; this is the first
    # inconsistency (its status is not inspected here).
    inference_engine.force_add_production(production_3)
    status_4 = inference_engine.force_add_production(production_4)
    status_5 = inference_engine.force_add_production(production_5)

    with self.subTest('should_detect_the_second_inconsistency'):
      self.assertIsNotNone(status_4.inconsistency)

    with self.subTest('should_detect_both_implication_and_inconsistency'):
      # production_5 is derivable by composing production_1 into
      # production_0 (an implication), and it assigns 'jump' 'twice' a
      # different meaning than production_4 (an inconsistency).
      self.assertIsNotNone(status_5.inconsistency)
      self.assertIsNotNone(status_5.implication)

    with self.subTest('should_have_correct_productions_by_input_tokens'):
      # Both the force-added production_3 and the composed production that
      # contradicts it are kept under the same input tokens.
      input_tokens = ('run', 'twice')
      expected_composed_production_string = (
          "S[sem=('RUN'+'RUN')] -> 'run' 'twice'")
      expected_composed_production = (
          nltk_utils.production_from_production_string(
              expected_composed_production_string))
      expected_productions = {production_3, expected_composed_production}
      self.assertSetEqual(
          inference_engine._productions_by_input_tokens[input_tokens],
          expected_productions)

  def test_should_track_multiple_provenances_with_known_intermediate_production(
      self):
    # Tests the scenario where an inference engine inferred a production p
    # but now we discovered a new provenance to p (without directly adding p to
    # the inference engine). This is a special case and is not covered by the
    # other test cases.
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x2+?x1)] -> V[sem=?x1] 'after' U[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'")
    production_2 = nltk_utils.production_from_production_string(
        "U[sem='WALK'] -> 'walk'")
    production_3 = nltk_utils.production_from_production_string(
        "V[sem=(?x1+?x1)] -> U[sem=?x1] 'twice'")
    production_4 = nltk_utils.production_from_production_string(
        "U[sem='JUMP'] -> 'jump'")
    # The intermediate production is
    # S[sem=(?x1+'JUMP'+'JUMP')] -> 'jump' 'twice' 'after' U[sem=?x1]
    # which can be either reached by composing production_1 into production_0
    # or by composing production_3 into production_0 (index 0) and then
    # composing production_4 into the new production.

    target_production = nltk_utils.production_from_production_string(
        "S[sem=('WALK'+'JUMP'+'JUMP')] -> 'jump' 'twice' 'after' 'walk'")

    inference_engine = inference.InferenceEngine(
        track_multiple_provenances=True)
    inference_engine.force_add_production(production_0)
    inference_engine.force_add_production(production_1)
    inference_engine.force_add_production(production_2)
    inference_engine.force_add_production(production_3)
    inference_engine.force_add_production(production_4)

    # Each provenance lists (production, RHS index) compositions applied to
    # the source in order.
    expected_target_provenances = [
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_1, 0), (production_2, 3))),
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_2, 2), (production_1, 0))),
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_3, 0), (production_4, 0),
                          (production_2, 3))),
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_3, 0), (production_2, 3),
                          (production_4, 0))),
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_2, 2), (production_3, 0),
                          (production_4, 0))),
    ]

    self.assertCountEqual(
        expected_target_provenances,
        inference_engine.provenances_by_production[target_production])

  def test_should_track_multiple_provenances_single_source(self):
    """Provenances of sources and inferred productions from one root rule."""
    # Source productions
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "U[sem='JUMP'] -> 'jump'")
    production_2 = nltk_utils.production_from_production_string(
        "U[sem='WALK'] -> 'walk'")

    # Source productions that are also inferred
    production_3 = nltk_utils.production_from_production_string(
        "V[sem=('JUMP'+?x1)] -> 'jump' 'and' U[sem=?x1]")

    # Inferred productions
    production_4 = nltk_utils.production_from_production_string(
        "V[sem=(?x1+'WALK')] -> U[sem=?x1] 'and' 'walk'")
    production_5 = nltk_utils.production_from_production_string(
        "V[sem=('JUMP'+'WALK')] -> 'jump' 'and' 'walk'")

    inference_engine = inference.InferenceEngine(
        track_multiple_provenances=True)
    inference_engine.force_add_production(production_0)
    inference_engine.force_add_production(production_3)
    inference_engine.force_add_production(production_1)
    inference_engine.force_add_production(production_2)

    # Source productions have themselves as provenance.
    provenances_0 = [
        production_composition.ProductionProvenance(source=production_0)
    ]
    provenances_1 = [
        production_composition.ProductionProvenance(source=production_1)
    ]
    provenances_2 = [
        production_composition.ProductionProvenance(source=production_2)
    ]

    # Source productions that are also inferred have multiple provenances.
    provenances_3 = [
        production_composition.ProductionProvenance(source=production_3),
        production_composition.ProductionProvenance(
            source=production_0, compositions=((production_1, 0),)),
    ]

    # Inferred production with only a single provenance.
    provenances_4 = [
        production_composition.ProductionProvenance(
            source=production_0, compositions=((production_2, 2),))
    ]

    # Inferred production with multiple provenances.
    provenances_5 = [
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_1, 0), (production_2, 2))),
        production_composition.ProductionProvenance(
            source=production_0,
            compositions=((production_2, 2), (production_1, 0))),
        production_composition.ProductionProvenance(
            source=production_3, compositions=((production_2, 2),)),
    ]

    with self.subTest('single_provenance'):
      self.assertCountEqual(
          provenances_0,
          inference_engine.provenances_by_production[production_0])
      self.assertCountEqual(
          provenances_1,
          inference_engine.provenances_by_production[production_1])
      self.assertCountEqual(
          provenances_2,
          inference_engine.provenances_by_production[production_2])
      self.assertCountEqual(
          provenances_4,
          inference_engine.provenances_by_production[production_4])

    with self.subTest('multi_provenances_naive_case'):
      self.assertCountEqual(
          provenances_3,
          inference_engine.provenances_by_production[production_3])

    with self.subTest('multi_provenance_with_multiple_intermediate_paths'):
      self.assertCountEqual(
          provenances_5,
          inference_engine.provenances_by_production[production_5])

  def test_should_track_multiple_provenances_multiple_sources(self):
    """One inferred production reachable from two different source rules."""
    source_0 = nltk_utils.production_from_production_string(
        "V[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]")
    source_1 = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x1)] -> U[sem=?x1] 'and' U[sem=?x2]")
    production_0 = nltk_utils.production_from_production_string(
        "U[sem='JUMP'] -> 'jump'")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem=('JUMP'+'JUMP')] -> 'jump' 'and' 'jump'")

    inference_engine = inference.InferenceEngine(
        track_multiple_provenances=True)
    inference_engine.force_add_production(source_0)
    inference_engine.force_add_production(source_1)
    inference_engine.force_add_production(production_0)

    provenances_0 = [
        production_composition.ProductionProvenance(source=production_0)
    ]
    provenances_1 = [
        production_composition.ProductionProvenance(
            source=source_0,
            compositions=((production_0, 0), (production_0, 2))),
        production_composition.ProductionProvenance(
            source=source_0,
            compositions=((production_0, 2), (production_0, 0))),
        production_composition.ProductionProvenance(
            source=source_1,
            compositions=((production_0, 2), (production_0, 0))),
        production_composition.ProductionProvenance(
            source=source_1,
            compositions=((production_0, 0), (production_0, 2)))
    ]

    with self.subTest('single_provenance'):
      self.assertCountEqual(
          provenances_0,
          inference_engine.provenances_by_production[production_0])

    with self.subTest('multi_source_provenances'):
      self.assertCountEqual(
          provenances_1,
          inference_engine.provenances_by_production[production_1])

  def test_exhaustive_search(self):
    """Every meaning of an ambiguous token is composed, not just the first."""
    production_string_0 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_string_1 = "V[sem='JUMP'] -> 'jump'"
    # Deliberately gives 'jump' a second meaning.
    production_string_2 = "V[sem='WALK'] -> 'jump'"
    production_string_3 = "S[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'"
    production_string_4 = "S[sem=('WALK'+'WALK')] -> 'jump' 'twice'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)
    production_3 = nltk_utils.production_from_production_string(
        production_string_3)
    production_4 = nltk_utils.production_from_production_string(
        production_string_4)

    expected_composed_productions = {production_3, production_4}
    inference_engine = inference.InferenceEngine()
    inference_engine.force_add_production(production_0, is_monotonic=True)
    inference_engine.force_add_production(production_1, is_monotonic=True)
    inference_engine.force_add_production(production_2, is_monotonic=True)
    self.assertContainsSubset(expected_composed_productions,
                              inference_engine.all_productions)


class InconsistencyIfProductionAddedTest(absltest.TestCase):
  """Tests for `InferenceEngine.inconsistency_if_production_added`.

  Each test seeds an engine with productions flagged monotonic or defeasible,
  then asks whether a candidate production would be inconsistent and, if so,
  with which qualifier.
  """

  @staticmethod
  def _make_engine(*seeds):
    """Builds an engine from (production_string, is_monotonic) pairs.

    The productions are parsed and passed to `add_production` in the order
    given.

    Args:
      *seeds: (production_string, is_monotonic) pairs.

    Returns:
      The populated `inference.InferenceEngine`.
    """
    engine = inference.InferenceEngine()
    for seed_string, monotonic in seeds:
      engine.add_production(
          nltk_utils.production_from_production_string(seed_string),
          is_monotonic=monotonic)
    return engine

  def _assert_would_be_inconsistent(self, engine, candidate, qualifier):
    """Asserts `candidate` would cause an inconsistency of type `qualifier`."""
    inconsistency = engine.inconsistency_if_production_added(candidate)
    self.assertIsNotNone(inconsistency)
    self.assertEqual(qualifier, inconsistency.type)

  def test_should_detect_contradicting_monotonic_production(self):
    engine = self._make_engine(("V[sem=RUN] -> 'walk'", True))
    candidate = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")
    self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.M)

  def test_should_detect_contradicting_defeasible_production(self):
    engine = self._make_engine(("V[sem=RUN] -> 'walk'", False))
    candidate = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")
    self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.D)

  def test_should_detect_derived_contradicting_monotonic_production(self):
    engine = self._make_engine(
        ("S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'", True),
        ("S[sem=WALK] -> 'walk' 'twice'", True))
    # The candidate does not directly contradict any seeded production, but a
    # production derived from it does.
    candidate = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")
    self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.M)

  def test_should_detect_derived_contradicting_defeasible_production(self):
    engine = self._make_engine(
        ("S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'", True),
        ("S[sem=WALK] -> 'walk' 'twice'", False))
    # Same derived contradiction as above, but against a defeasible seed.
    candidate = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")
    self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.D)

  def test_should_detect_consistent_production(self):
    engine = self._make_engine(
        ("S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'", True),
        ("S[sem=(WALK+WALK)] -> 'walk' 'twice'", False))
    candidate = nltk_utils.production_from_production_string(
        "V[sem=WALK] -> 'walk'")
    self.assertIsNone(engine.inconsistency_if_production_added(candidate))

  def test_should_detect_monotonically_inconsistent_production(self):
    parse = nltk_utils.production_from_production_string
    engine = self._make_engine(
        ("V[sem=LOOK] -> 'run'", True),
        ("D[sem=LTURN] -> 'right'", True),
        ("S[sem=(?x1+?x2+?x1)] -> V[sem=?x1] 'around' D[sem=?x2]", False))

    composed = parse("S[sem=(LOOK+LTURN+LOOK)] -> 'run' 'around' 'right'")

    # The candidate defeasibly contradicts the third seed, and its monotonic
    # composition "S[sem=(LOOK+LOOK+LOOK)] -> 'run' 'around' 'right'"
    # contradicts `composed`.
    candidate = parse(
        "S[sem=(?x1+?x1+?x1)] -> V[sem=?x1] 'around' D[sem=?x2]")

    with self.subTest(
        'should_contain_composed_production_in_correct_collection'):
      self.assertIn(composed, engine.defeasible_productions)

    # While the contradicted productions are defeasible, so is the reported
    # inconsistency.
    with self.subTest('should_detect_inconsistency'):
      self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.D)

    # Promote the composed production to monotonic.
    engine.add_production(composed, is_monotonic=True)

    with self.subTest('should_promote_production_to_monotonic'):
      self.assertIn(composed, engine.monotonic_productions)

    # Adding the candidate would now let the engine monotonically derive
    # "S[sem=(LOOK+LOOK+LOOK)] -> 'run' 'around' 'right'", which contradicts
    # the now-monotonic `composed`, so the inconsistency becomes monotonic.
    with self.subTest('should_detect_monotonic_inconsistency'):
      self._assert_would_be_inconsistent(engine, candidate, cl.Qualifier.M)


class SourceProductionsTrackingTest(absltest.TestCase):
  """Tests how the engine attributes productions to their source productions."""

  def test_should_record_source_productions(self):
    walk_rule, twice_rule, and_jump_rule = (
        nltk_utils.production_from_production_string(rule_string)
        for rule_string in (
            "V[sem=WALK] -> 'walk'",
            "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
            "C[sem=(?x1+JUMP)] -> S[sem=?x1] 'and' 'jump'",
        ))
    seeded_rules = (walk_rule, twice_rule, and_jump_rule)

    engine = inference.InferenceEngine()
    for rule in seeded_rules:
      engine.add_production(rule)

    # Inferred by composing walk_rule into twice_rule.
    walk_twice = nltk_utils.production_from_production_string(
        "S[sem=(WALK+WALK)] -> 'walk' 'twice'")
    # Inferred by further composing walk_twice into and_jump_rule.
    walk_twice_and_jump = nltk_utils.production_from_production_string(
        "C[sem=(WALK+WALK+JUMP)] -> 'walk' 'twice' 'and' 'jump'")

    # Anything passed to add_production is its own (sole) source.
    with self.subTest(
        'should_consider_manually_added_productions_as_source_production'):
      for rule in seeded_rules:
        self.assertIn(rule, engine.source_productions)
      for rule in seeded_rules:
        self.assertEqual(engine.get_source_productions(rule), [rule])

    with self.subTest('should_not_include_compositions_as_source_productions'):
      self.assertNotIn(walk_twice, engine.source_productions)
      self.assertNotIn(walk_twice_and_jump, engine.source_productions)

    with self.subTest('should_record_parents_as_source_productions'):
      self.assertCountEqual(
          engine.get_source_productions(walk_twice), [walk_rule, twice_rule])

    # Sources are resolved transitively through inferred parents.
    with self.subTest(
        'should_recursively_include_source_productions_of_parents'):
      self.assertCountEqual(
          engine.get_source_productions(walk_twice_and_jump), list(seeded_rules))

    # Explicitly adding an inferred production turns it into a source, which
    # must be reflected in its own and its descendants' source lists.
    engine.add_production(walk_twice)

    with self.subTest('should_record_added_productions_as_source_productions'):
      self.assertIn(walk_twice, engine.source_productions)
      self.assertEqual(engine.get_source_productions(walk_twice), [walk_twice])

    with self.subTest('should_update_source_productions_of_descendants'):
      self.assertCountEqual(
          engine.get_source_productions(walk_twice_and_jump),
          [walk_twice, and_jump_rule])


class ProductionsByNumVariablesTest(absltest.TestCase):
  """Tests for `InferenceEngine.get_productions_of_num_variables`."""

  def test_should_record_productions_by_num_variables(self):
    """Added and inferred productions are bucketed by their variable count."""
    production_string_0 = "V[sem=WALK] -> 'walk'"
    production_0 = nltk_utils.production_from_production_string(
        production_string_0)
    production_string_1 = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
    production_1 = nltk_utils.production_from_production_string(
        production_string_1)
    production_string_2 = "S[sem=(?x1+?x1+?x1)] -> V[sem=?x1] 'thrice'"
    production_2 = nltk_utils.production_from_production_string(
        production_string_2)

    # The composed semantics use `+` concatenation, matching the `(?x1+?x1)`
    # form of the parent productions (not a comma-separated tuple).
    composed_production_string_01 = "S[sem=(WALK+WALK)] -> 'walk' 'twice'"
    composed_production_01 = nltk_utils.production_from_production_string(
        composed_production_string_01)
    composed_production_string_02 = (
        "S[sem=(WALK+WALK+WALK)] -> 'walk' 'thrice'")
    composed_production_02 = nltk_utils.production_from_production_string(
        composed_production_string_02)

    inference_engine = inference.InferenceEngine()
    inference_engine.add_production(production_0)
    inference_engine.add_production(production_1)
    inference_engine.add_production(production_2)

    with self.subTest(
        'should_record_productions_by_num_variables_including_descendants'):
      # Compositions substitute away ?x1, so they join the 0-variable bucket.
      self.assertSetEqual(
          inference_engine.get_productions_of_num_variables(0),
          {production_0, composed_production_01, composed_production_02})
      self.assertSetEqual(
          inference_engine.get_productions_of_num_variables(1),
          {production_1, production_2})
      self.assertSetEqual(
          inference_engine.get_productions_of_num_variables(2), set())


if __name__ == '__main__':
  # Hand control to absltest's runner when this file is executed directly.
  absltest.main()
