# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import traceback

from absl.testing import absltest
from absl.testing import parameterized
import networkx as nx
import nltk

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import production_trees

# Short alias for the class under test, used throughout the tests below.
ProductionTree = production_trees.ProductionTree


class ProductionTreeTest(absltest.TestCase):
  """Tests for ProductionTree construction, composition and substitution maps.

  Covers `from_production`, `from_production_provenance`,
  `get_composed_production`, and the `get_variable_substitutions_*` /
  `get_outer_substitutions_*` accumulator methods.
  """

  def test_from_production_creates_empty_child_for_each_variable(self):
    production = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_tree = ProductionTree.from_production(production)
    expected_production_tree = ProductionTree(production, children=[None, None])
    self.assertEqual(str(expected_production_tree), str(production_tree))

  def test_from_production_provenance_basic(self):
    # Composed rule: 'x1 opposite left'
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_0, compositions=((production_1, 2),))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    # Note that 'None' is inserted in position 0 of the `children` list, since
    # we composed the productions at position 2 (i.e., variable index 1) rather
    # than position 0 (variable index 0).
    expected_production_tree = ProductionTree(
        production_0, children=[None, ProductionTree(production_1)])
    self.assertEqual(str(expected_production_tree), str(production_tree))

  def test_from_production_provenance_nested(self):
    # Composed rule: 'x1 opposite left twice after x2'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_after,
        compositions=((production_twice, 0), (production_opposite, 0),
                      (production_left, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    expected_production_tree = ProductionTree(
        production_after,
        children=[
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_opposite,
                        children=[None, ProductionTree(production_left)])
                ]),
            None,
        ])
    self.assertEqual(str(expected_production_tree), str(production_tree))

  def test_get_composed_production_should_recover_production(self):
    # This is inspired by a similar test in `production_composition_test.py`:
    # `test_get_production_should_recover_production`. The idea is that we take
    # all the productions generatable from a simple grammar, convert their
    # production provenances into production trees, further convert those into
    # composed productions, and then verify they are the same as the productions
    # we started with.

    # Here we use a minimal grammar to save time but still capture the shape of
    # the SCAN grammar.
    grammar_string = """% start C
    C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]
    C[sem=?x1] -> S[sem=?x1]
    S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'
    S[sem=?x1] -> V[sem=?x1]
    V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]
    V[sem=?x1] -> D[sem=?x1]
    D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
    D[sem=?x1] -> U[sem=?x1]
    U[sem='WALK'] -> 'walk'
    W[sem='RTURN'] -> 'right'"""
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)

    provenance_by_production = production_composition.ProductionProvenanceDict()
    inference_engine = inference.InferenceEngine(
        provenance_by_production=provenance_by_production)

    # We use the inference engine to automate the process of creating all
    # possible compositions.  With the grammar above there are 693 productions
    # in total.
    for source_production in grammar.productions():
      inference_engine.add_production(source_production)

    # Now we do the round-trip conversion for each of the generated productions.
    # If this test fails, it will presumably be due to an issue in
    # ProductionTree, as the round-trip conversion between Production and
    # ProductionProvenance is already tested in production_composition_test.py.
    for production in inference_engine.all_productions:
      # One subTest per production: a failure is recorded against the specific
      # production and the loop continues, so a single bad production does not
      # hide failures in the others. A `self.fail` below ends only the current
      # subTest, so `production_tree` is never used when it was not assigned.
      with self.subTest(production=str(production)):
        provenance = provenance_by_production[production]
        try:
          production_tree = ProductionTree.from_production_provenance(
              provenance)
        except Exception:
          self.fail('Exception raised when converting ProductionProvenance '
                    'to ProductionTree:\n'
                    f'  Original production={production}\n'
                    f'  ProductionProvenance={provenance}\n'
                    f'  Exception={traceback.format_exc()}')
        try:
          recovered_production = production_tree.get_composed_production()
        except Exception:
          self.fail(
              'Exception raised when converting ProductionTree to '
              'production:\n'
              f'  Original production={production}\n'
              f'  ProductionTree={production_tree}\n'
              f'  Exception={traceback.format_exc()}')
        self.assertEqual(str(production), str(recovered_production))

  def test_get_variable_substitutions_nested(self):
    # Composed rule: 'x1 opposite left twice after x2 opposite left twice'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_tree = ProductionTree(
        production_after,
        children=[
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_opposite,
                        children=[None, ProductionTree(production_left)])
                ]),
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_opposite,
                        children=[None, ProductionTree(production_left)])
                ])
        ])

    # get_variable_substitutions_by_production: Each production appears as a key
    # in the outer mapping.
    variable_substitutions_by_production = {}
    production_tree.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    expected_by_production = {
        production_after: {
            'x1': {'x1 opposite left twice'},
            'x2': {'x1 opposite left twice'}
        },
        production_twice: {
            'x1': {'x1 opposite left'}
        },
        production_opposite: {
            'x1': {'x1'},
            'x2': {'left'},
        },
        production_left: {}
    }
    with self.subTest('by_production'):
      self.assertDictEqual(expected_by_production,
                           variable_substitutions_by_production)

    # get_variable_substitutions_by_rule: Same as by_production, except that the
    # keys of the outer mapping are in interpretation rule format.
    variable_substitutions_by_rule = {}
    production_tree.get_variable_substitutions_by_rule(
        variable_substitutions_by_rule, enums.RuleFormat.INTERPRETATION_RULE)
    expected_by_rule = {
        '[x1 after x2] = [x2] [x1]': {
            'x1': {'x1 opposite left twice'},
            'x2': {'x1 opposite left twice'}
        },
        '[x1 twice] = [x1] [x1]': {
            'x1': {'x1 opposite left'}
        },
        '[x1 opposite x2] = [x2] [x2] [x1]': {
            'x1': {'x1'},
            'x2': {'left'},
        },
        '[left] = LTURN': {}
    }
    with self.subTest('by_rule'):
      self.assertDictEqual(expected_by_rule, variable_substitutions_by_rule)

  def test_get_variable_substitutions_called_on_multiple_production_trees(self):
    """Merges the substitutions together.

    This is what happens when building the variable substitutions map for a
    context as a whole. One call to get_variable_substitutions_by_production
    will be made for each context example, and the resulting data structure
    will show for each source production all of the different variable
    substitutions that were illustrated for that production across all of the
    examples in the context.
    """
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_jump = nltk_utils.production_from_production_string(
        "S[sem='JUMP'] -> 'jump'")

    # Here we create two ProductionTrees that have the same top production, but
    # two different substitutions for the first variable.
    production_tree_twice_after = ProductionTree.from_production(
        production_after)
    production_tree_twice_after.children[0] = ProductionTree.from_production(
        production_twice)
    production_tree_jump_after = ProductionTree.from_production(
        production_after)
    production_tree_jump_after.children[0] = ProductionTree.from_production(
        production_jump)

    # get_variable_substitutions_by_production: Each production appears as a key
    # in the outer mapping.
    variable_substitutions_by_production = {}
    production_tree_twice_after.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    production_tree_jump_after.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    expected_by_production = {
        # This part is merged from the first and second calls.
        production_after: {
            'x1': {'x1 twice', 'jump'},
            'x2': {'x2'}
        },
        # This part comes from the first call.
        production_twice: {
            'x1': {'x1'}
        },
        # This part comes from the second call.
        production_jump: {},
    }
    self.assertDictEqual(expected_by_production,
                         variable_substitutions_by_production)

  def test_get_variable_substitutions_called_on_same_tree_multiple_times(self):
    """Idempotent because the substitutions are represented as sets."""
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_tree_after = ProductionTree.from_production(production_after)

    # Same expected result after first call and after second call.
    expected_by_production = {
        production_after: {
            'x1': {'x1'},
            'x2': {'x2'}
        },
    }

    variable_substitutions_by_production = {}
    production_tree_after.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    with self.subTest('after_first_call'):
      self.assertDictEqual(expected_by_production,
                           variable_substitutions_by_production)
    production_tree_after.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    with self.subTest('after_second_call'):
      self.assertDictEqual(expected_by_production,
                           variable_substitutions_by_production)

  def test_get_variable_substitutions_with_passthrough_rules(self):
    # Composed rule: 'jump twice'
    production_c_s = nltk_utils.production_from_production_string(
        'C[sem=?x1] -> S[sem=?x1]')
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_v_u = nltk_utils.production_from_production_string(
        'V[sem=?x1] -> U[sem=?x1]')
    production_jump = nltk_utils.production_from_production_string(
        "U[sem='JUMP'] -> 'jump'")

    production_tree = ProductionTree(
        production_c_s,
        children=[
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_v_u,
                        children=[ProductionTree(production_jump)])
                ])
        ])

    # get_variable_substitutions_by_production: Each production appears as a key
    # in the outer mapping (including pass-through rules).
    variable_substitutions_by_production = {}
    production_tree.get_variable_substitutions_by_production(
        variable_substitutions_by_production)
    expected_by_production = {
        production_c_s: {
            'x1': {'jump twice'},
        },
        production_twice: {
            'x1': {'jump'},
        },
        production_v_u: {
            'x1': {'jump'},
        },
        production_jump: {}
    }
    with self.subTest('by_production'):
      self.assertDictEqual(expected_by_production,
                           variable_substitutions_by_production)

    # get_variable_substitutions_by_rule: Same as by_production, except that the
    # keys of the outer mapping are in interpretation rule format, and
    # pass-through rules are skipped.
    variable_substitutions_by_rule = {}
    production_tree.get_variable_substitutions_by_rule(
        variable_substitutions_by_rule, enums.RuleFormat.INTERPRETATION_RULE)
    expected_by_rule = {
        '[x1 twice] = [x1] [x1]': {
            'x1': {'jump'}
        },
        '[jump] = JUMP': {},
    }
    with self.subTest('by_rule'):
      self.assertDictEqual(expected_by_rule, variable_substitutions_by_rule)

  def test_get_outer_substitutions_nested(self):
    # Composed rule: 'x1 opposite left twice after x2 opposite left twice'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_tree = ProductionTree(
        production_after,
        children=[
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_opposite,
                        children=[None, ProductionTree(production_left)])
                ]),
            ProductionTree(
                production_twice,
                children=[
                    ProductionTree(
                        production_opposite,
                        children=[None, ProductionTree(production_left)])
                ])
        ])

    # get_outer_substitutions_by_production: Each production appears as a key.
    # Outer substitutions are determined by the input string of the first non-
    # pass-through parent production.
    outer_substitutions_by_production = {}
    production_tree.get_outer_substitutions_by_production(
        outer_substitutions_by_production)
    expected_by_production = {
        production_after: {'__'},
        # Note that there are two different outer substitutions for the two
        # occurrences of the production, even though they were from the same
        # example. Variables are normalized in the outer substitution string.
        production_twice: {
            '__ after x1 opposite left twice', 'x1 opposite left twice after __'
        },
        # Note that there is only one outer substitution here, even though the
        # production appeared twice, as the outer substitutions in both cases
        # were the same.
        production_opposite: {'__ twice'},
        # Note that there is again only one outer substitution here, as the two
        # occurrences 'x1 opposite __' and 'x2 opposite __' would both have the
        # same input string 'x1 opposite __' when variables are normalized.
        production_left: {'x1 opposite __'}
    }
    with self.subTest('by_production'):
      self.assertDictEqual(expected_by_production,
                           outer_substitutions_by_production)

    # get_outer_substitutions_by_rule: Same as by_production, except that the
    # keys are in interpretation rule format.
    outer_substitutions_by_rule = {}
    production_tree.get_outer_substitutions_by_rule(
        outer_substitutions_by_rule, enums.RuleFormat.INTERPRETATION_RULE)
    expected_by_rule = {
        '[x1 after x2] = [x2] [x1]': {'__'},
        '[x1 twice] = [x1] [x1]': {
            '__ after x1 opposite left twice', 'x1 opposite left twice after __'
        },
        '[x1 opposite x2] = [x2] [x2] [x1]': {'__ twice'},
        '[left] = LTURN': {'x1 opposite __'}
    }
    with self.subTest('by_rule'):
      self.assertDictEqual(expected_by_rule, outer_substitutions_by_rule)

  def test_get_outer_substitutions_called_on_multiple_production_trees(self):
    """Merges the substitutions together.

    This is what happens when building the outer substitutions map for a context
    as a whole. One call to get_outer_substitutions_by_production will be made
    for each context example, and the resulting data structure will show for
    each source production all of the different outer substitutions that were
    illustrated for that production across all of the examples in the context.
    """
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_jump = nltk_utils.production_from_production_string(
        "S[sem='JUMP'] -> 'jump'")

    # Here we create two ProductionTrees that have the same top production, but
    # two different substitutions for the first variable.
    production_tree_twice_after_jump = ProductionTree.from_production(
        production_after)
    production_tree_twice_after_jump.children[0] = (
        ProductionTree.from_production(production_twice))
    production_tree_twice_after_jump.children[1] = (
        ProductionTree.from_production(production_jump))
    production_tree_jump_after = ProductionTree.from_production(
        production_after)
    production_tree_jump_after.children[0] = ProductionTree.from_production(
        production_jump)

    # get_outer_substitutions_by_production: Each production appears as a key
    # in the outer mapping.
    outer_substitutions_by_production = {}
    production_tree_twice_after_jump.get_outer_substitutions_by_production(
        outer_substitutions_by_production)
    production_tree_jump_after.get_outer_substitutions_by_production(
        outer_substitutions_by_production)
    expected_by_production = {
        # These parts are merged from the first and second calls.
        production_after: {'__'},
        production_jump: {'__ after x1', 'x1 twice after __'},
        # This part comes from the first call.
        production_twice: {'__ after jump'},
    }
    self.assertDictEqual(expected_by_production,
                         outer_substitutions_by_production)

  def test_get_outer_substitutions_called_on_same_tree_multiple_times(self):
    """Idempotent because the substitutions are represented as sets."""
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_tree_after = ProductionTree.from_production(production_after)

    # Same expected result after first call and after second call.
    expected_by_production = {
        production_after: {'__'},
    }

    outer_substitutions_by_production = {}
    production_tree_after.get_outer_substitutions_by_production(
        outer_substitutions_by_production)
    with self.subTest('after_first_call'):
      self.assertDictEqual(expected_by_production,
                           outer_substitutions_by_production)
    production_tree_after.get_outer_substitutions_by_production(
        outer_substitutions_by_production)
    with self.subTest('after_second_call'):
      self.assertDictEqual(expected_by_production,
                           outer_substitutions_by_production)


class EffectiveNumSubstitutionsTest(parameterized.TestCase):
  """Tests for the effective-number-of-substitutions helpers and tree views.

  The first two tests cover the module-level helpers
  `get_effective_num_outer_substitutions` and
  `get_effective_min_num_variable_substitutions`.

  NOTE(review): the remaining tests (`graph`, `production_from_path`,
  `get_subtrees_containing_root`, `get_all_subtrees`) exercise ProductionTree
  methods rather than these helpers; consider moving them into
  ProductionTreeTest.
  """

  @parameterized.named_parameters(
      # Outer substitutions are a set of strings, so the empty case is an empty
      # set (not an empty dict).
      ('no_substitutions', set(), 0),
      ('never_top_of_rule_tree', {'a __', '__ b'}, 2),
      ('top_of_rule_tree', {'__'}, production_trees.INFINITE_SUBSTITUTIONS),
      ('sometimes_top_of_rule_tree', {'a __', '__'},
       production_trees.INFINITE_SUBSTITUTIONS),
  )
  def test_get_effective_num_outer_substitutions(self, substitutions, expected):
    self.assertEqual(
        expected,
        production_trees.get_effective_num_outer_substitutions(substitutions))

  @parameterized.named_parameters(
      ('no_variables', {}, production_trees.INFINITE_SUBSTITUTIONS),
      ('single_variable_counts_substitutions', {
          'x1': {'a', 'b'}
      }, 2),
      ('multiple_variables_takes_min', {
          'x1': {'a', 'b', 'c'},
          'x2': {'d', 'e'},
      }, 2),
      ('unsubstituted_variable_treated_as_infinite_substitutions', {
          'x1': {'x1', 'a'}
      }, production_trees.INFINITE_SUBSTITUTIONS),
      ('multiple_variables_takes_min_even_when_one_is_unsubstituted', {
          'x1': {'x1'},
          'x2': {'d', 'e'},
      }, 2),
  )
  def test_get_effective_min_num_variable_substitutions(self, substitutions_map,
                                                        expected):
    self.assertEqual(
        expected,
        production_trees.get_effective_min_num_variable_substitutions(
            substitutions_map))

  def test_graph_repeated_production(self):
    # Composed rule: 'walk and walk'
    production_0 = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x2)] -> V[sem=?x1] 'and' V[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "V[sem='WALK'] -> 'walk'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_0,
        # Even though the same production appears multiple times in the
        # production tree, the different occurrences need to be represented
        # by different graph nodes.
        compositions=((production_1, 0), (production_1, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    expected_graph = nx.Graph()
    expected_graph.add_nodes_from([(), (0,), (1,)])
    expected_graph.add_edges_from([((), (0,)), ((), (1,))])

    self.assertEqual(expected_graph.nodes, production_tree.graph.nodes)
    self.assertEqual(expected_graph.edges, production_tree.graph.edges)

  def test_graph_nested(self):
    # Composed rule: 'x1 opposite left twice after x2'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_after,
        compositions=((production_twice, 0), (production_opposite, 0),
                      (production_left, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    expected_graph = nx.Graph()
    expected_graph.add_nodes_from([(), (0,), (0, 0), (0, 0, 1)])
    expected_graph.add_edges_from([((), (0,)), ((0,), (0, 0)),
                                   ((0, 0), (0, 0, 1))])

    self.assertEqual(expected_graph.nodes, production_tree.graph.nodes)
    self.assertEqual(expected_graph.edges, production_tree.graph.edges)

  def test_production_from_path_nested(self):
    # Composed rule: 'x1 opposite left twice after x2'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_after,
        compositions=((production_twice, 0), (production_opposite, 0),
                      (production_left, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    # (expected, actual) order, consistent with the rest of the file, so that
    # failure messages label the two sides correctly.
    self.assertEqual(production_after, production_tree.production_from_path(()))
    self.assertEqual(
        production_twice, production_tree.production_from_path((0,)))
    self.assertEqual(
        production_opposite, production_tree.production_from_path((0, 0)))
    self.assertEqual(
        production_left, production_tree.production_from_path((0, 0, 1)))

  def test_get_subtrees_connected_to_root_basic(self):
    # Composed rule: 'x1 opposite left'
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_0, compositions=((production_1, 2),))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    self.assertCountEqual(
        production_tree.get_subtrees_containing_root(1),
        [ProductionTree.from_production(production_0)])
    self.assertCountEqual(
        production_tree.get_subtrees_containing_root(2), [production_tree])

  def test_get_subtrees_connected_to_root_nested(self):
    # Composed rule: 'x1 opposite left twice after walk'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")
    production_walk = nltk_utils.production_from_production_string(
        "S[sem='WALK'] -> 'walk'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_after,
        compositions=((production_walk, 2), (production_twice, 0),
                      (production_opposite, 0), (production_left, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    production_tree_after = ProductionTree.from_production(production_after)

    production_tree_after_twice = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_after, compositions=((production_twice, 0),)))
    production_tree_after_walk = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_after, compositions=((production_walk, 2),)))

    production_tree_after_twice_opposite = (
        ProductionTree.from_production_provenance(
            production_composition.ProductionProvenance(
                source=production_after,
                compositions=((production_twice, 0), (production_opposite,
                                                      0)))))
    production_tree_after_walk_twice = (
        ProductionTree.from_production_provenance(
            production_composition.ProductionProvenance(
                source=production_after,
                compositions=((production_walk, 2), (production_twice, 0)))))

    self.assertCountEqual(
        production_tree.get_subtrees_containing_root(1),
        [production_tree_after])
    self.assertCountEqual(
        production_tree.get_subtrees_containing_root(2),
        [production_tree_after_twice, production_tree_after_walk])
    self.assertCountEqual(
        production_tree.get_subtrees_containing_root(3), [
            production_tree_after_twice_opposite,
            production_tree_after_walk_twice
        ])

  def test_get_all_subtrees_basic(self):
    # Composed rule: 'x1 opposite left'
    production_0 = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_1 = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_0, compositions=((production_1, 2),))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    self.assertCountEqual(
        production_tree.get_all_subtrees(1), [
            ProductionTree.from_production(production_0),
            ProductionTree.from_production(production_1)
        ])
    self.assertCountEqual(
        production_tree.get_all_subtrees(2), [production_tree])

    # The tree has only two nodes.
    self.assertCountEqual(production_tree.get_all_subtrees(3), [])

  def test_get_all_subtrees_nested(self):
    # Composed rule: 'x1 opposite left twice after walk'
    production_after = nltk_utils.production_from_production_string(
        "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]")
    production_twice = nltk_utils.production_from_production_string(
        "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'")
    production_opposite = nltk_utils.production_from_production_string(
        "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
    production_left = nltk_utils.production_from_production_string(
        "W[sem='LTURN'] -> 'left'")
    production_walk = nltk_utils.production_from_production_string(
        "S[sem='WALK'] -> 'walk'")

    production_provenance = production_composition.ProductionProvenance(
        source=production_after,
        compositions=((production_walk, 2), (production_twice, 0),
                      (production_opposite, 0), (production_left, 2)))

    production_tree = ProductionTree.from_production_provenance(
        production_provenance)

    # Subtrees of size 2.
    production_tree_after_twice = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_after, compositions=((production_twice, 0),)))
    production_tree_after_walk = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_after, compositions=((production_walk, 2),)))
    production_tree_twice_opposite = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_twice, compositions=((production_opposite, 0),)))
    production_tree_opposite_left = ProductionTree.from_production_provenance(
        production_composition.ProductionProvenance(
            source=production_opposite, compositions=((production_left, 2),)))

    # Subtrees of size 3.
    production_tree_after_twice_opposite = (
        ProductionTree.from_production_provenance(
            production_composition.ProductionProvenance(
                source=production_after,
                compositions=((production_twice, 0), (production_opposite,
                                                      0)))))
    production_tree_after_walk_twice = (
        ProductionTree.from_production_provenance(
            production_composition.ProductionProvenance(
                source=production_after,
                compositions=((production_walk, 2), (production_twice, 0)))))
    production_tree_twice_opposite_left = (
        ProductionTree.from_production_provenance(
            production_composition.ProductionProvenance(
                source=production_twice,
                compositions=((production_opposite, 0), (production_left, 2)))))

    self.assertCountEqual(
        production_tree.get_all_subtrees(1), [
            ProductionTree.from_production(production_after),
            ProductionTree.from_production(production_twice),
            ProductionTree.from_production(production_opposite),
            ProductionTree.from_production(production_left),
            ProductionTree.from_production(production_walk)
        ])
    self.assertCountEqual(
        production_tree.get_all_subtrees(2), [
            production_tree_after_twice, production_tree_after_walk,
            production_tree_twice_opposite, production_tree_opposite_left
        ])
    self.assertCountEqual(
        production_tree.get_all_subtrees(3), [
            production_tree_after_twice_opposite,
            production_tree_after_walk_twice,
            production_tree_twice_opposite_left
        ])


# Run the tests in this module when it is executed directly.
if __name__ == '__main__':
  absltest.main()
