# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import traceback
from typing import Any, Mapping, Optional, TypeVar

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import grammar_generation as gg
from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import test_utils

# Generic type variable for use in type annotations within this module.
T = TypeVar('T')


def _create_grammar_options_contradicting_scan():
  """Builds GrammarOptions whose settings disagree with the SCAN grammar.

  The reference point is the GrammarSchema returned by
  test_utils.get_grammar_schema_for_scan_finite_nye_standardized(). Apart from
  the vocabularies (which match SCAN's), every setting below differs from what
  that schema actually contains. Passing these options to GrammarGenerator
  together with a partially-populated template derived from that schema
  therefore reveals exactly which parts of the generated GrammarSchema were
  copied from the template and which were produced according to the options.

  Returns:
    An inputs.GrammarOptions instance.
  """
  contradicting_settings = dict(
      # Same vocabularies as scan_finite_nye_standardized.
      input_vocabulary=inputs.INPUT_VOCABULARY_SIZE_100,
      output_vocabulary=inputs.OUTPUT_VOCABULARY_SIZE_100,
      # SCAN: 7 primitives.
      num_primitives=4,
      # SCAN: 4 precedence levels.
      num_precedence_levels=2,
      # SCAN: 1-2 categories per level.
      min_num_categories_per_level=3,
      max_num_categories_per_level=3,
      # SCAN: 0-2 functions per level.
      min_num_functions_per_level=3,
      max_num_functions_per_level=3,
      # SCAN: 1-2 args per rule.
      min_num_args=3,
      max_num_args=3,
      # SCAN: 0-1 postfix args per rule.
      min_num_postfix_args=2,
      max_num_postfix_args=2,
      # SCAN: concat and pass-through rules are always present.
      prob_pass_through_rule=0,
      prob_concat_rule=0,
      # SCAN: a token may be repeated up to 4 times.
      max_repetitions_per_token_in_output_sequence=1,
      # SCAN: output sequences contain no raw tokens.
      min_unique_raw_tokens_in_output_sequence=1,
      max_unique_raw_tokens_in_output_sequence=1,
  )
  return inputs.GrammarOptions(**contradicting_settings)


class GenerateGrammarTest(absltest.TestCase):
  """Tests for GrammarGenerator.generate_grammar."""

  def setUp(self):
    super().setUp()
    # Seeded so that failures are reproducible.
    self.rng = np.random.RandomState(42)

  def test_should_return_a_non_empty_grammar(self):
    generator = gg.GrammarGenerator(rng=self.rng)
    grammar = generator.generate_grammar()
    self.assertNotEmpty(grammar.productions())

  def test_should_return_a_grammar_with_a_valid_start_symbol(self):
    generator = gg.GrammarGenerator(rng=self.rng)
    grammar = generator.generate_grammar()
    self.assertNotEmpty(grammar.start())

  def test_should_return_different_grammar_each_time(self):
    generator = gg.GrammarGenerator(rng=self.rng)
    grammar1 = generator.generate_grammar()
    grammar2 = generator.generate_grammar()
    # Compare string renderings (as the determinism test below does). If the
    # grammar type does not define value equality, assertNotEqual on the
    # objects themselves would fall back to identity and pass trivially.
    self.assertNotEqual(str(grammar1), str(grammar2))

  def test_should_behave_deterministically_for_a_given_random_seed(self):
    # Note that this test only ensures deterministic behavior within a single
    # process. It is still possible that behavior is non-deterministic across
    # runs at different times on different machines.
    random_seed = 42
    generator1 = gg.GrammarGenerator(rng=np.random.RandomState(random_seed))
    grammar1 = generator1.generate_grammar()
    generator2 = gg.GrammarGenerator(rng=np.random.RandomState(random_seed))
    grammar2 = generator2.generate_grammar()
    self.assertEqual(str(grammar1), str(grammar2))


class GenerateGrammarSchemaTest(parameterized.TestCase):
  """Tests for GrammarGenerator.generate_grammar_schema.

  All generators in this class are created via _create_generator, which wires
  in the RandomState seeded in setUp, so that a failing run is reproducible.
  """

  def setUp(self):
    super().setUp()
    self.rng = np.random.RandomState(42)

  def _create_generator(
      self,
      options: Optional[inputs.GrammarOptions] = None
  ) -> gg.GrammarGenerator:
    """Returns a GrammarGenerator that draws its randomness from self.rng.

    Args:
      options: Options for the generator. If None, a default-constructed
        inputs.GrammarOptions is used.
    """
    if options is None:
      options = inputs.GrammarOptions()
    return gg.GrammarGenerator(options=options, rng=self.rng)

  def _generate_grammar_schema_and_assert_valid(
      self,
      generator: gg.GrammarGenerator,
      template: Optional[gs.GrammarSchema] = None) -> gs.GrammarSchema:
    """Generates a GrammarSchema, failing the test if it is not valid.

    Args:
      generator: The generator to invoke.
      template: Optional partially-populated schema passed through to
        generate_grammar_schema.

    Returns:
      The generated schema.
    """
    try:
      schema = generator.generate_grammar_schema(template)
    except ValueError:
      self.fail(f'Exception raised in generate_grammar_schema (likely due to '
                f'failed validation): {traceback.format_exc()}')
    # In practice, the following check should be superfluous, as the generated
    # grammar should have already been validated by generate_grammar_schema
    # before returning. Including this check explicitly anyway, however, in
    # case the validation code were ever removed from generate_grammar_schema.
    self.assertTrue(
        schema.is_valid(generator.options), f'Invalid schema: {schema}')
    return schema

  def assertAllValuesBetween(self, items: Mapping[T, Any], min_value: Any,
                             max_value: Any, schema: gs.GrammarSchema) -> None:
    """Fails unless every value in `items` lies in [min_value, max_value].

    Args:
      items: Mapping whose values are checked (e.g. counts keyed by level or
        by rule string).
      min_value: Inclusive lower bound.
      max_value: Inclusive upper bound.
      schema: Schema the values were computed from; rendered into the failure
        message to aid debugging.
    """
    if not all(min_value <= value <= max_value for value in items.values()):
      self.fail(f'Mapping unexpectedly contains value not between '
                f'"{min_value}" and "{max_value}": {items}\nin: '
                f'{schema.to_grammar_string()}')

  def test_should_always_return_a_valid_grammar(self):
    # Note that this test does much of the heavy lifting in ensuring that
    # GrammarGenerator behaves as expected, as much of the complexity of
    # GrammarGenerator's implementation is for ensuring that the invariants
    # encapsulated in GrammarSchema.validate() are satisfied by the generated
    # GrammarSchema.
    generator = self._create_generator()
    num_grammars_to_generate = 100
    logging.info('Generating %d random grammar schemas',
                 num_grammars_to_generate)
    for i in range(num_grammars_to_generate):
      schema = self._generate_grammar_schema_and_assert_valid(generator)
      logging.info('\nGrammar schema %d of %d = %s', i,
                   num_grammars_to_generate, schema.to_grammar_string())

  def test_should_return_different_grammar_each_time(self):
    generator = self._create_generator()
    schema1 = generator.generate_grammar_schema()
    schema2 = generator.generate_grammar_schema()
    self.assertNotEqual(schema1, schema2)

  def test_with_pass_through_and_concat_rules_disabled(self):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            prob_pass_through_rule=0.0,
            prob_concat_rule=0.0,
        ))
    schema = self._generate_grammar_schema_and_assert_valid(generator)
    logging.info('\nGrammar schema without concat or pass-through rules = %s',
                 schema.to_grammar_string())
    with self.subTest('should_contain_no_PassThroughRule'):
      self.assertEmpty(schema.pass_through_rules)
    with self.subTest('should_contain_no_ConcatRule'):
      self.assertIsNone(schema.concat_rule)

  @parameterized.named_parameters(
      ('low_min_and_max', 1, 1),
      ('typical_case', 2, 4),
      ('high_min_and_max', 8, 8),
  )
  def test_should_honor_min_max_primitives(self, min_value, max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_primitives=min_value, max_num_primitives=max_value))
    schema = generator.generate_grammar_schema()
    self.assertBetween(len(schema.primitives), min_value, max_value)

  @parameterized.named_parameters(
      ('low_min_and_max', 1, 1),
      ('typical_case', 2, 3),
      ('high_min_and_max', 4, 4),
  )
  def test_should_honor_min_max_num_precedence_levels(self, min_value,
                                                      max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_precedence_levels=min_value,
            max_num_precedence_levels=max_value,
        ))
    schema = generator.generate_grammar_schema()
    self.assertBetween(schema.get_max_level(), min_value, max_value)

  @parameterized.named_parameters(
      ('low_min_and_max', 1, 1),
      ('typical_case', 2, 3),
      ('high_min_and_max', 4, 4),
  )
  def test_should_honor_min_max_num_categories_per_level_except_last_level(
      self, min_value, max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_categories_per_level=min_value,
            max_num_categories_per_level=max_value,
            possible_categories_by_level=inputs
            .POSSIBLE_CATEGORIES_BY_LEVEL_8_PER_LEVEL,
        ))
    schema = generator.generate_grammar_schema()
    # range(get_max_level()) stops just before the last level, which is
    # checked separately below.
    num_categories_by_level_except_last = {
        level: len(schema.get_rule_categories_for_level(level))
        for level in range(schema.get_max_level())
    }
    with self.subTest(
        name='levels_except_last_should_honor_min_max_num_categories'):
      self.assertAllValuesBetween(num_categories_by_level_except_last,
                                  min_value, max_value, schema)
    with self.subTest('last_level_should_always_have_one_category'):
      self.assertLen(
          schema.get_rule_categories_for_level(schema.get_max_level()), 1)

  @parameterized.named_parameters(
      ('low_min_and_max', 1, 1),
      ('typical_case', 2, 3),
      ('high_min_and_max', 4, 4),
  )
  def test_should_honor_min_max_num_functions_per_level(self, min_value,
                                                        max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_functions_per_level=min_value,
            max_num_functions_per_level=max_value,
            # We are limiting the number of categories as well for the purposes
            # of this test, as max_num_functions_per_level may be overridden if
            # a larger number of functions is required in order to consume all
            # of the categories that were output by the previous level.
            max_num_categories_per_level=1,
            # Using larger vocabularies to avoid running out of function tokens
            input_vocabulary=inputs.INPUT_VOCABULARY_SIZE_100,
            output_vocabulary=inputs.OUTPUT_VOCABULARY_SIZE_100,
        ))
    schema = generator.generate_grammar_schema()
    num_functions_by_level = {
        level: len(schema.functions_by_level.get(level, []))
        for level in range(1, schema.get_max_level())
    }
    self.assertAllValuesBetween(num_functions_by_level, min_value, max_value,
                                schema)

  @parameterized.named_parameters(
      ('low_min_and_max', 1, 1),
      ('typical_case', 2, 3),
      ('high_min_and_max', 4, 4),
  )
  def test_should_honor_min_max_num_args(self, min_value, max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_args=min_value,
            max_num_args=max_value,
        ))
    schema = generator.generate_grammar_schema()
    num_args_by_rule = {
        rule.to_rule_string(): len(rule.args) for rule in
        itertools.chain.from_iterable(schema.functions_by_level.values())
    }
    self.assertAllValuesBetween(num_args_by_rule, min_value, max_value, schema)

  @parameterized.named_parameters(
      ('prefix_args_only', 0, 0, 3, 3),
      ('mixture_of_prefix_and_postfix_args', 1, 2, 3, 3),
      ('postfix_args_only', 3, 3, 3, 3),
  )
  def test_should_honor_min_max_num_postfix_args(self, min_num_postfix_args,
                                                 max_num_postfix_args,
                                                 min_num_args, max_num_args):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_args=min_num_args,
            max_num_args=max_num_args,
            min_num_postfix_args=min_num_postfix_args,
            max_num_postfix_args=max_num_postfix_args,
        ))
    schema = generator.generate_grammar_schema()
    num_postfix_args_by_rule = {
        rule.to_rule_string(): rule.num_postfix_args for rule in
        itertools.chain.from_iterable(schema.functions_by_level.values())
    }
    self.assertAllValuesBetween(num_postfix_args_by_rule, min_num_postfix_args,
                                max_num_postfix_args, schema)

  @parameterized.named_parameters(
      ('no_repetition_and_no_raw_tokens', 1, 0, -1, -1, -1),
      ('no_repetition_but_with_raw_tokens', 1, 2, -1, -1, -1),
      ('with_repetition_but_no_raw_tokens', 2, 0, -1, -1, -1),
      ('with_repetition_and_raw_tokens', 2, 2, -1, -1, -1),
      ('with_repetition_and_raw_tokens_and_limits', 3, 2, 3, 7, -1),
      ('with_max_cumulative_size', 8, 4, 8, 16, 20),
  )
  def test_should_honor_output_sequence_constraints(self, max_repetitions,
                                                    num_unique_raw_tokens,
                                                    max_raw_tokens, max_size,
                                                    max_cumulative_size):
    # In the parameters above, a value of -1 means "no limit" (every limit is
    # only applied below when it is >= 0).
    num_args = 2
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_num_args=num_args,
            max_num_args=num_args,
            max_repetitions_per_token_in_output_sequence=max_repetitions,
            min_unique_raw_tokens_in_output_sequence=num_unique_raw_tokens,
            max_unique_raw_tokens_in_output_sequence=num_unique_raw_tokens,
            max_raw_tokens_in_output_sequence=max_raw_tokens,
            max_output_sequence_size=max_size,
            max_cumulative_output_sequence_size=max_cumulative_size,
            # Using larger vocabularies to avoid running out of tokens
            input_vocabulary=inputs.INPUT_VOCABULARY_SIZE_100,
            output_vocabulary=inputs.OUTPUT_VOCABULARY_SIZE_100,
        ))
    schema = generator.generate_grammar_schema()
    # Expected bounds on the raw-token portion of each output sequence.
    min_raw_length = num_unique_raw_tokens
    max_raw_length = num_unique_raw_tokens * max_repetitions
    if max_raw_tokens >= 0:
      min_raw_length = min(min_raw_length, max_raw_tokens)
      max_raw_length = min(max_raw_length, max_raw_tokens)
    # Each variable appears at least once and at most max_repetitions times.
    min_output_length = min_raw_length + num_args
    max_output_length = max_raw_length + num_args * max_repetitions
    if max_size >= 0:
      min_output_length = min(min_output_length, max_size)
      max_output_length = min(max_output_length, max_size)
    if max_cumulative_size >= 0:
      min_output_length = 0

    output_sequence_length_by_rule = {
        rule.to_rule_string(): len(rule.get_output_sequence()) for rule in
        itertools.chain.from_iterable(schema.functions_by_level.values())
    }
    with self.subTest('individual_sequence_constraints'):
      self.assertAllValuesBetween(output_sequence_length_by_rule,
                                  min_output_length, max_output_length, schema)

    if max_cumulative_size >= 0:
      cumulative_size = sum(output_sequence_length_by_rule.values())
      num_vars = sum(len(rule.get_args()) for rule in schema.get_all_rules())
      with self.subTest('cumulative_sequence_constraints'):
        # In the worst-case, the max_cumulative_size is exceeded by num_vars
        # because we use a greedy algorithm and need to make sure that each
        # output sequence contains all the rule's variables at least once.
        self.assertLessEqual(cumulative_size, max_cumulative_size + num_vars)

  @parameterized.named_parameters(
      ('specific_value_below_min', 3, 4, 5),
      ('specific_value_above_max', 6, 4, 5),
  )
  def test_specific_num_primitives_should_override_min_max(
      self, specific_value, min_value, max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            num_primitives=specific_value,
            min_num_primitives=min_value,
            max_num_primitives=max_value,
        ))
    schema = generator.generate_grammar_schema()
    self.assertLen(schema.primitives, specific_value)

  @parameterized.named_parameters(
      ('specific_value_below_min', 1, 2, 3),
      ('specific_value_above_max', 3, 1, 2),
  )
  def test_specific_num_precedence_levels_should_override_min_max(
      self, specific_value, min_value, max_value):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            num_precedence_levels=specific_value,
            min_num_precedence_levels=min_value,
            max_num_precedence_levels=max_value,
        ))
    schema = generator.generate_grammar_schema()
    self.assertEqual(schema.get_max_level(), specific_value)

  def test_should_use_only_tokens_from_input_vocabulary(self):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            input_vocabulary=frozenset(f'i{n}' for n in range(50)),))
    schema = generator.generate_grammar_schema()
    logging.info('\nGrammar schema with custom input vocabulary = %s',
                 schema.to_grammar_string())

    input_tokens = schema.get_input_token_usage_counts().keys()
    oov_input_tokens = set(input_tokens).difference(
        set(generator.options.input_vocabulary))
    self.assertEmpty(oov_input_tokens)

  def test_should_use_only_tokens_from_output_vocabulary(self):
    generator = self._create_generator(
        options=inputs.GrammarOptions(
            min_unique_raw_tokens_in_output_sequence=1,
            max_unique_raw_tokens_in_output_sequence=1,
            output_vocabulary=frozenset(f'o{n}' for n in range(50)),
        ))
    schema = generator.generate_grammar_schema()
    logging.info('\nGrammar schema with custom output vocabulary = %s',
                 schema.to_grammar_string())

    output_tokens = schema.get_output_token_usage_counts().keys()
    oov_output_tokens = set(output_tokens).difference(
        set(generator.options.output_vocabulary))
    self.assertEmpty(oov_output_tokens)

  def test_generate_grammar_schema_identical_to_scan(self):
    # If a fully-specified GrammarSchema is passed in as the template, then
    # the generated grammar should simply be a clone of the template.
    template = test_utils.get_grammar_schema_for_scan_finite_nye_standardized()
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    schema = generator.generate_grammar_schema(template)
    logging.info('\nGrammar schema identical to SCAN = %s',
                 schema.to_grammar_string())
    with self.subTest('generated_schema_should_be_equivalent_to_template'):
      self.assertEqual(template.to_grammar_string(), schema.to_grammar_string())
    with self.subTest(
        name='generated_schema_should_be_a_separate_copy_from_template'):
      self.assertIsNot(template, schema)

  def test_generate_grammar_schema_with_same_primitives_as_scan(self):
    # Primitives fully specified, but not non-primitives.
    scan_grammar = (
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())
    template = gs.GrammarSchema(primitives=scan_grammar.primitives)
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    # A valid schema means that the missing pieces were properly filled in.
    schema = self._generate_grammar_schema_and_assert_valid(generator, template)
    logging.info('\nGrammar schema with same primitives as SCAN = %s',
                 schema.to_grammar_string())
    with self.subTest('primitives_should_be_a_copy_of_those_in_template'):
      self.assertIsNot(template.primitives, schema.primitives)
      self.assertEqual(template.primitives, schema.primitives)
    with self.subTest('non_primitives_should_differ_from_template'):
      self.assertNotEqual(template.functions_by_level,
                          schema.functions_by_level)
    with self.subTest(
        name='non_primitives_should_use_different_tokens_than_the_primitives'):
      # Every token being used exactly once implies no token is shared.
      self.assertAllValuesBetween(schema.get_input_token_usage_counts(), 1, 1,
                                  schema)
      self.assertAllValuesBetween(schema.get_output_token_usage_counts(), 1, 1,
                                  schema)

  def test_generate_grammar_schema_with_same_primitive_inputs_as_scan(self):
    # Here we specify only the primitive input sequences, not the outputs.
    scan_grammar = (
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())
    template = gs.GrammarSchema(
        primitives=copy.deepcopy(scan_grammar.primitives))
    for rule in template.primitives:
      rule.output_sequence = []
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    # A valid schema means that the missing pieces were properly filled in.
    schema = self._generate_grammar_schema_and_assert_valid(generator, template)
    logging.info('\nGrammar schema with same primitive inputs as SCAN = %s',
                 schema.to_grammar_string())
    with self.subTest('primitive_inputs_should_match_template'):
      self.assertEqual([rule.input_sequence for rule in template.primitives],
                       [rule.input_sequence for rule in schema.primitives])
    with self.subTest('primitive_outputs_should_differ_from_template'):
      self.assertNotEqual(
          [rule.output_sequence for rule in template.primitives],
          [rule.output_sequence for rule in schema.primitives])

  def test_generate_grammar_schema_with_same_non_primitive_rules_as_scan(self):
    # Here we specify the non-primitive rules (and thus primitive categories),
    # but not the primitive rules themselves.
    template = test_utils.get_grammar_schema_for_scan_finite_nye_standardized()
    template.primitives = []
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    # A valid schema means that the missing pieces were properly filled in.
    schema = self._generate_grammar_schema_and_assert_valid(generator, template)
    logging.info('\nGrammar schema with same non-primitives as SCAN = %s',
                 schema.to_grammar_string())
    with self.subTest('functions_should_match_template'):
      self.assertEqual(template.functions_by_level, schema.functions_by_level)
    with self.subTest('concat_rule_should_match_template'):
      self.assertEqual(template.concat_rule_level, schema.concat_rule_level)
      self.assertEqual(template.concat_rule, schema.concat_rule)
    with self.subTest('pass_through_rules_should_match_template'):
      self.assertEqual(template.pass_through_rules, schema.pass_through_rules)
    with self.subTest('primitives_should_differ_from_template'):
      self.assertNotEqual(template.primitives, schema.primitives)
    with self.subTest(
        name='primitives_should_use_different_tokens_than_the_non_primitives'):
      # Every token being used exactly once implies no token is shared.
      self.assertAllValuesBetween(schema.get_input_token_usage_counts(), 1, 1,
                                  schema)
      self.assertAllValuesBetween(schema.get_output_token_usage_counts(), 1, 1,
                                  schema)

  def test_generate_grammar_schema_with_same_numbers_of_rules_as_scan(self):
    # Here we specify the number of rules per level, but no rule details.
    scan_grammar = (
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())
    template = gs.GrammarSchema(
        primitives=[], functions_by_level={}, pass_through_rules={})
    template.primitives = [
        gs.PrimitiveMapping() for _ in scan_grammar.primitives
    ]
    for level, rules in scan_grammar.functions_by_level.items():
      template.functions_by_level[level] = [gs.FunctionRule() for _ in rules]
    for level in scan_grammar.pass_through_rules.keys():
      template.pass_through_rules[level] = gs.PassThroughRule()
    template.concat_rule_level = scan_grammar.concat_rule_level
    if scan_grammar.concat_rule is not None:
      template.concat_rule = gs.ConcatRule()
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    # A valid schema means that the missing pieces were properly filled in.
    schema = self._generate_grammar_schema_and_assert_valid(generator, template)
    logging.info('\nGrammar schema with same number of rules as SCAN = %s',
                 schema.to_grammar_string())
    with self.subTest('num_primitives_should_match_template'):
      self.assertLen(schema.primitives, len(template.primitives))
    with self.subTest('num_functions_per_level_should_match_template'):
      self.assertEqual(
          {
              level: len(rules)
              for level, rules in template.functions_by_level.items()
          }, {
              level: len(rules)
              for level, rules in schema.functions_by_level.items()
          })
    with self.subTest('pass_through_rule_levels_should_match_template'):
      self.assertEqual(template.pass_through_rules.keys(),
                       schema.pass_through_rules.keys())
    with self.subTest('concat_rule_level_should_match_template'):
      self.assertEqual(template.concat_rule_level, schema.concat_rule_level)

  def test_generate_grammar_schema_with_same_shape_of_rules_as_scan(self):
    # Here we specify rules and args, but not input or output tokens.
    template = test_utils.get_grammar_schema_for_scan_finite_nye_standardized()
    for rule in template.primitives:
      rule.input_sequence = []
      rule.output_sequence = []
    for rule in itertools.chain.from_iterable(
        template.functions_by_level.values()):
      rule.function_phrase = []
    generator = self._create_generator(
        options=_create_grammar_options_contradicting_scan())
    # A valid schema means that the missing pieces were properly filled in.
    schema = self._generate_grammar_schema_and_assert_valid(generator, template)
    logging.info('\nGrammar schema with same shape of rules as SCAN = %s',
                 schema.to_grammar_string())
    # NOTE(review): range(get_max_level()) stops before the top level itself
    # (elsewhere in this class get_max_level() is used as a valid level index),
    # so the comparisons below skip the top level. Confirm whether that level
    # should be compared as well.
    with self.subTest('num_levels_should_match_template'):
      self.assertEqual(template.get_max_level(), schema.get_max_level())
    with self.subTest('rule_args_should_match_template'):
      self.assertEqual(
          {
              level: template.get_args_for_level(level)
              for level in range(template.get_max_level())
          }, {
              level: schema.get_args_for_level(level)
              for level in range(schema.get_max_level())
          })
    with self.subTest('rule_categories_should_match_template'):
      self.assertEqual(
          {
              level: template.get_rule_categories_for_level(level)
              for level in range(template.get_max_level())
          }, {
              level: schema.get_rule_categories_for_level(level)
              for level in range(schema.get_max_level())
          })


class FixedPhraseStructureTemplateFromGrammarSchemaTest(absltest.TestCase):
  """Tests for gg.fixed_phrase_structure_template_from_grammar_schema."""

  def test_fixed_phrase_structure_template_from_grammar_schema(self):
    scan_schema = (
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())
    phrase_structure = gg.fixed_phrase_structure_template_from_grammar_schema(
        scan_schema)
    actual_rule_strings = [
        rule.to_rule_string() for rule in phrase_structure.get_all_rules()
    ]
    # The expected rules keep SCAN's categories, function phrases and
    # primitive input tokens, while every non-pass-through rule has an empty
    # sem feature. Order is irrelevant (assertCountEqual).
    expected_rule_strings = [
        "C[sem=''] -> S[sem=?x1] 'and' S[sem=?x2]",
        "C[sem=''] -> S[sem=?x1] 'after' S[sem=?x2]",
        'C[sem=?x1] -> S[sem=?x1]',
        "S[sem=''] -> V[sem=?x1] 'twice'",
        "S[sem=''] -> V[sem=?x1] 'thrice'",
        'S[sem=?x1] -> V[sem=?x1]',
        "V[sem=''] -> U[sem=?x1] 'opposite' W[sem=?x2]",
        "V[sem=''] -> U[sem=?x1] 'around' W[sem=?x2]",
        'V[sem=?x1] -> D[sem=?x1]',
        "D[sem=''] -> U[sem=?x1] W[sem=?x2]",
        'D[sem=?x1] -> U[sem=?x1]',
        "U[sem=''] -> 'walk'",
        "U[sem=''] -> 'look'",
        "U[sem=''] -> 'run'",
        "U[sem=''] -> 'jump'",
        "U[sem=''] -> 'turn'",
        "W[sem=''] -> 'left'",
        "W[sem=''] -> 'right'",
    ]
    self.assertCountEqual(actual_rule_strings, expected_rule_strings)


# Runs all tests in this module when the file is executed directly.
if __name__ == '__main__':
  absltest.main()
