# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap
import traceback
from typing import Dict, List, Optional

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import test_utils

# The following are example syntactic categories used in the tests in this file.
# The exact choice of category is not important, as long as each one is a valid
# category for its precedence level. They are all drawn from
# inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL, indexed as [level][position].
# The strings 'U', 'D', etc. have no particular semantic meaning, but are chosen
# as representative categories for the benefit of readers familiar with the
# original SCAN grammar, which uses these symbols in this precedence order.
L0_CATEGORY = inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL[0][0]  # 'U'
L0_CATEGORY_2 = inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL[0][1]  # 'W'
L1_CATEGORY = inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL[1][0]  # 'D'
L2_CATEGORY = inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL[2][0]  # 'V'
L3_CATEGORY = inputs.POSSIBLE_CATEGORIES_BY_LEVEL_MINIMAL[3][0]  # 'S'


def _create_primitive_mapping(
    input_sequence: Optional[List[str]] = None,
    output_sequence: Optional[List[str]] = None,
    category: str = L0_CATEGORY) -> gs.PrimitiveMapping:
  """Creates a PrimitiveMapping with arbitrary values for test purposes.

  Args:
    input_sequence: Input tokens of the mapping. Defaults to ['walk'].
    output_sequence: Output tokens of the mapping. Defaults to ['WALK'].
    category: Syntactic category assigned to the mapping.

  Returns:
    A new PrimitiveMapping. Default lists are created afresh on every call, so
    mutating the result never leaks into other tests.
  """
  if input_sequence is None:
    input_sequence = ['walk']
  if output_sequence is None:
    output_sequence = ['WALK']
  return gs.PrimitiveMapping(
      input_sequence=input_sequence,
      output_sequence=output_sequence,
      category=category)


def _create_single_arg_function_rule(
    function_phrase: Optional[List[str]] = None,
    category: str = L1_CATEGORY,
    arg_category: str = L0_CATEGORY,
    variable: str = '?x1',
    output_sequence: Optional[List[str]] = None) -> gs.FunctionRule:
  """Creates a one-argument FunctionRule with arbitrary values for tests.

  The single argument is a prefix argument (num_postfix_args=0), so with the
  defaults the rule renders as "D[sem=?x1] -> U[sem=?x1] 'f'".

  Args:
    function_phrase: Tokens of the function phrase. Defaults to ['f'].
    category: Syntactic category produced by the rule.
    arg_category: Syntactic category of the single argument.
    variable: Variable name bound to the argument.
    output_sequence: Output sequence of the rule. Defaults to [variable], i.e.
      the rule outputs exactly its argument's meaning.

  Returns:
    A new FunctionRule.
  """
  if function_phrase is None:
    function_phrase = ['f']
  if output_sequence is None:
    output_sequence = [variable]
  return gs.FunctionRule(
      function_phrase=function_phrase,
      category=category,
      num_args=1,
      num_postfix_args=0,
      args=[
          gs.FunctionArg(variable=variable, category=arg_category),
      ],
      output_sequence=output_sequence)


def _create_two_arg_function_rule(
    function_phrase: Optional[List[str]] = None,
    category: str = L1_CATEGORY,
    arg_category: str = L0_CATEGORY,
    arg2_category: Optional[str] = None,
    variable1: str = '?x1',
    variable2: str = '?x2',
    output_sequence: Optional[List[str]] = None) -> gs.FunctionRule:
  """Creates a two-argument FunctionRule with arbitrary values for tests.

  The first argument precedes the function phrase and the second follows it
  (num_postfix_args=1), e.g. "V[sem=(?x2+?x1)] -> D[sem=?x1] 'g' D[sem=?x2]".

  Args:
    function_phrase: Tokens of the function phrase. Defaults to ['g'].
    category: Syntactic category produced by the rule.
    arg_category: Syntactic category of the first (prefix) argument.
    arg2_category: Syntactic category of the second (postfix) argument.
      Defaults to arg_category.
    variable1: Variable name bound to the first argument.
    variable2: Variable name bound to the second argument.
    output_sequence: Output sequence of the rule. Defaults to
      [variable2, variable1], i.e. the arguments' meanings in reverse order.

  Returns:
    A new FunctionRule.
  """
  if function_phrase is None:
    function_phrase = ['g']
  if arg2_category is None:
    arg2_category = arg_category
  if output_sequence is None:
    output_sequence = [variable2, variable1]
  return gs.FunctionRule(
      function_phrase=function_phrase,
      category=category,
      num_args=2,
      num_postfix_args=1,
      args=[
          gs.FunctionArg(variable=variable1, category=arg_category),
          gs.FunctionArg(variable=variable2, category=arg2_category),
      ],
      output_sequence=output_sequence)


def _create_concat_rule(category: str = L1_CATEGORY,
                        arg_category: str = L0_CATEGORY) -> gs.ConcatRule:
  """Creates a ConcatRule over two args of the same category for tests.

  The rule's output swaps its arguments (output_sequence ['?x2', '?x1']).

  Args:
    category: Syntactic category produced by the rule.
    arg_category: Syntactic category shared by both arguments.

  Returns:
    A new ConcatRule.
  """
  return gs.ConcatRule(
      category=category,
      arg1=gs.FunctionArg(variable='?x1', category=arg_category),
      arg2=gs.FunctionArg(variable='?x2', category=arg_category),
      output_sequence=['?x2', '?x1'])


def _create_pass_through_rule(
    category: str = L1_CATEGORY,
    arg_category: str = L0_CATEGORY) -> gs.PassThroughRule:
  """Creates a PassThroughRule lifting an arg_category phrase to category.

  Args:
    category: Syntactic category produced by the rule.
    arg_category: Syntactic category of the single argument '?x1'.

  Returns:
    A new PassThroughRule.
  """
  return gs.PassThroughRule(
      category=category,
      arg=gs.FunctionArg(variable='?x1', category=arg_category))


def _create_single_level_schema(
    level_0_category: str = L0_CATEGORY,
    level_1_category: str = L1_CATEGORY,
    pass_through_rules: Optional[Dict[int, gs.PassThroughRule]] = None,
    concat_rule_level: Optional[int] = None,
    concat_rule: Optional[gs.ConcatRule] = None) -> gs.GrammarSchema:
  """Returns a simple valid 1-level GrammarSchema reusable across tests.

  See CreateSingleLevelSchemaTest below for the corresponding FeatureGrammar in
  fcfg syntax.

  Args:
    level_0_category: Syntactic category output by level 0.
    level_1_category: Syntactic category output by level 1.
    pass_through_rules: Optional dict of level to PassThroughRule. Defaults to
      a new empty dict.
    concat_rule_level: Optional level of ConcatRule.
    concat_rule: Optional ConcatRule.

  Returns:
    A new GrammarSchema with one primitive ('a' -> 'A') at level 0 and one
    single-argument FunctionRule ('f') at level 1.
  """
  if pass_through_rules is None:
    pass_through_rules = {}
  schema = gs.GrammarSchema(
      primitives=[
          gs.PrimitiveMapping(
              input_sequence=['a'],
              output_sequence=['A'],
              category=level_0_category),
      ],
      functions_by_level={
          1: [
              _create_single_arg_function_rule(
                  function_phrase=['f'],
                  category=level_1_category,
                  arg_category=level_0_category)
          ],
      },
      pass_through_rules=pass_through_rules,
      concat_rule_level=concat_rule_level,
      concat_rule=concat_rule)
  return schema


class CreateSingleLevelSchemaTest(absltest.TestCase):
  """Shows, in fcfg form, the grammar built by _create_single_level_schema()."""

  def test_illustrate_fcfg_format_of_valid_single_level_schema(self):
    rendered = _create_single_level_schema().to_grammar_string()
    # One level-1 rule over one level-0 primitive; 'D' is the start symbol.
    self.assertEqual(
        textwrap.dedent("""\
        % start D
        D[sem=?x1] -> U[sem=?x1] 'f'
        U[sem='A'] -> 'a'"""), rendered)


def _create_two_level_schema(
    level_0_category: str = L0_CATEGORY,
    level_1_category: str = L1_CATEGORY,
    level_2_category: str = L2_CATEGORY,
    pass_through_rules: Optional[Dict[int, gs.PassThroughRule]] = None,
    concat_rule_level: Optional[int] = None,
    concat_rule: Optional[gs.ConcatRule] = None) -> gs.GrammarSchema:
  """Returns a simple valid 2-level GrammarSchema reusable across tests.

  See CreateTwoLevelSchemaTest below for the corresponding FeatureGrammar in
  fcfg syntax.

  Args:
    level_0_category: Syntactic category output by level 0.
    level_1_category: Syntactic category output by level 1.
    level_2_category: Syntactic category output by level 2.
    pass_through_rules: Optional dict of level to PassThroughRule.
    concat_rule_level: Optional level of ConcatRule.
    concat_rule: Optional ConcatRule.

  Returns:
    The schema from _create_single_level_schema() extended with one
    two-argument FunctionRule ('g') at level 2 over level-1 arguments.
  """
  schema = _create_single_level_schema(level_0_category, level_1_category,
                                       pass_through_rules, concat_rule_level,
                                       concat_rule)
  schema.functions_by_level[2] = [
      _create_two_arg_function_rule(
          function_phrase=['g'],
          category=level_2_category,
          arg_category=level_1_category)
  ]
  return schema


class CreateTwoLevelSchemaTest(absltest.TestCase):
  """Test illustrating the meaning of _create_two_level_schema()."""

  # Named for the two-level schema it exercises (was copy-pasted as
  # "..._single_level_schema").
  def test_illustrate_fcfg_format_of_valid_two_level_schema(self):
    schema = _create_two_level_schema()
    actual_grammar_string = schema.to_grammar_string()
    # The level-2 rule takes one prefix and one postfix level-1 argument and
    # outputs their meanings in reverse order.
    expected_grammar_string = textwrap.dedent("""\
        % start V
        V[sem=(?x2+?x1)] -> D[sem=?x1] 'g' D[sem=?x2]
        D[sem=?x1] -> U[sem=?x1] 'f'
        U[sem='A'] -> 'a'""")
    self.assertEqual(expected_grammar_string, actual_grammar_string)


class FunctionArgTest(absltest.TestCase):
  """Tests FunctionArg.to_string() for populated and default-constructed args."""

  def test_to_string_on_valid_arg(self):
    rendered = gs.FunctionArg(variable='x1', category=L2_CATEGORY).to_string()
    self.assertEqual('V[sem=x1]', rendered)

  def test_to_string_on_empty_should_not_throw_error(self):
    # Unset fields are rendered literally as None.
    self.assertEqual('None[sem=None]', gs.FunctionArg().to_string())


class PrimitiveMappingTest(absltest.TestCase):
  """Tests rendering and argument accessors of PrimitiveMapping."""

  def test_get_rhs_terms_empty(self):
    self.assertEqual([], gs.PrimitiveMapping().get_rhs_terms())

  def test_get_rhs_terms_single_token(self):
    mapping = gs.PrimitiveMapping(input_sequence=['abc'])
    self.assertEqual(["'abc'"], mapping.get_rhs_terms())

  def test_get_rhs_terms_multiple_tokens(self):
    # Each input token becomes its own quoted terminal.
    mapping = gs.PrimitiveMapping(input_sequence=['abc', 'def'])
    self.assertEqual(["'abc'", "'def'"], mapping.get_rhs_terms())

  def test_to_rule_string_on_valid_rule(self):
    mapping = gs.PrimitiveMapping(
        category=L0_CATEGORY,
        input_sequence=['abc'],
        output_sequence=['ABC'],
    )
    self.assertEqual("U[sem='ABC'] -> 'abc'", mapping.to_rule_string())

  def test_to_rule_string_on_empty_should_not_throw_error(self):
    self.assertEqual("None[sem=''] -> ", gs.PrimitiveMapping().to_rule_string())

  def test_args_are_always_empty(self):
    mapping = gs.PrimitiveMapping(
        category=L0_CATEGORY,
        input_sequence=['abc'],
        output_sequence=['ABC'],
    )
    # Primitives never take arguments, even when fully populated.
    self.assertEmpty(mapping.get_args())
    self.assertEmpty(mapping.get_arg_categories())


class FunctionRuleTest(parameterized.TestCase):
  """Tests FunctionRule arg bookkeeping and its rendering as an fcfg rule."""

  def test_num_prefix_args_is_number_of_args_that_are_not_postfix(self):
    # 5 args in total, 2 of them after the phrase -> 3 before it.
    function_rule = gs.FunctionRule(num_args=5, num_postfix_args=2)
    self.assertEqual(3, function_rule.num_prefix_args)

  def test_num_prefix_args_is_none_when_num_args_is_not_specified(self):
    self.assertIsNone(gs.FunctionRule(num_postfix_args=1).num_prefix_args)

  def test_num_prefix_args_is_none_when_num_postfix_args_is_not_specified(self):
    self.assertIsNone(gs.FunctionRule(num_args=1).num_prefix_args)

  def test_get_function_phrase_string_empty(self):
    self.assertEqual('', gs.FunctionRule().get_function_phrase_string())

  def test_get_function_phrase_string_single_token(self):
    function_rule = gs.FunctionRule(function_phrase=['abc'])
    self.assertEqual('abc', function_rule.get_function_phrase_string())

  def test_get_function_phrase_string_multiple_tokens(self):
    # Phrase tokens are joined with single spaces.
    function_rule = gs.FunctionRule(function_phrase=['abc', 'def'])
    self.assertEqual('abc def', function_rule.get_function_phrase_string())

  def test_get_rhs_terms_empty(self):
    self.assertEqual([], gs.FunctionRule().get_rhs_terms())

  def test_get_rhs_terms_function_phrase_only(self):
    function_rule = gs.FunctionRule(function_phrase=['abc', 'def'])
    self.assertEqual(["'abc'", "'def'"], function_rule.get_rhs_terms())

  @parameterized.named_parameters(
      ('prefix_args_appear_before_function_phrase', 2, 0,
       ['U[sem=?x1]', 'W[sem=?x2]', "'abc'", "'def'"]),
      ('postfix_args_appear_after_function_phrase', 2, 2,
       ["'abc'", "'def'", 'U[sem=?x1]', 'W[sem=?x2]']),
      ('num_postfix_args_capped_to_length_of_arg_list', 2, 3,
       ["'abc'", "'def'", 'U[sem=?x1]', 'W[sem=?x2]']),
      ('num_prefix_args_capped_to_length_of_arg_list', 3, 0,
       ['U[sem=?x1]', 'W[sem=?x2]', "'abc'", "'def'"]))
  def test_get_rhs_terms_num_prefix_postfix_args(self, num_args,
                                                 num_postfix_args,
                                                 expected_rhs_terms):
    two_args = [
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
    ]
    function_rule = gs.FunctionRule(
        function_phrase=['abc', 'def'],
        num_args=num_args,
        num_postfix_args=num_postfix_args,
        args=two_args)
    self.assertEqual(expected_rhs_terms, function_rule.get_rhs_terms())

  def test_to_rule_string_on_valid_rule(self):
    three_args = [
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
        gs.FunctionArg(variable='?x3', category=L0_CATEGORY),
    ]
    function_rule = gs.FunctionRule(
        function_phrase=['abc'],
        category=L1_CATEGORY,
        num_args=3,
        num_postfix_args=1,
        args=three_args,
        output_sequence=['?x1', '?x3', '?x2'])
    # Two prefix args precede 'abc'; the single postfix arg follows it.
    expected = "D[sem=(?x1+?x3+?x2)] -> U[sem=?x1] W[sem=?x2] 'abc' U[sem=?x3]"
    self.assertEqual(expected, function_rule.to_rule_string())

  def test_to_rule_string_on_empty_should_not_throw_error(self):
    self.assertEqual("None[sem=''] -> ", gs.FunctionRule().to_rule_string())

  def test_args_on_valid_rule(self):
    three_args = [
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
        gs.FunctionArg(variable='?x3', category=L0_CATEGORY),
    ]
    function_rule = gs.FunctionRule(args=three_args)
    self.assertEqual(three_args, function_rule.get_args())
    # Categories are reported without duplicates.
    self.assertCountEqual({L0_CATEGORY, L0_CATEGORY_2},
                          function_rule.get_arg_categories())


class PassThroughRuleTest(absltest.TestCase):
  """Tests the single-argument accessors and rendering of PassThroughRule."""

  def test_args_on_valid_rule(self):
    only_arg = gs.FunctionArg(variable='?x1', category=L0_CATEGORY)
    pass_through = gs.PassThroughRule(arg=only_arg)
    self.assertEqual([only_arg], pass_through.get_args())
    self.assertCountEqual({L0_CATEGORY}, pass_through.get_arg_categories())

  def test_to_rule_string_on_valid_rule(self):
    only_arg = gs.FunctionArg(variable='?x1', category=L0_CATEGORY)
    pass_through = gs.PassThroughRule(category=L1_CATEGORY, arg=only_arg)
    self.assertEqual('D[sem=?x1] -> U[sem=?x1]', pass_through.to_rule_string())

  def test_to_rule_string_on_empty_should_not_throw_error(self):
    rendered = gs.PassThroughRule().to_rule_string()
    self.assertEqual("None[sem=''] -> None[sem=None]", rendered)


class ConcatRuleTest(absltest.TestCase):
  """Tests the two-argument accessors and rendering of ConcatRule."""

  def test_args_on_valid_rule(self):
    first = gs.FunctionArg(variable='?x1', category=L0_CATEGORY)
    second = gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2)
    concat = gs.ConcatRule(arg1=first, arg2=second)
    self.assertEqual([first, second], concat.get_args())
    self.assertCountEqual({L0_CATEGORY, L0_CATEGORY_2},
                          concat.get_arg_categories())

  def test_to_rule_string_on_valid_rule(self):
    first = gs.FunctionArg(variable='?x1', category=L0_CATEGORY)
    second = gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2)
    concat = gs.ConcatRule(
        category=L1_CATEGORY,
        arg1=first,
        arg2=second,
        output_sequence=['?x2', '?x1'])
    # Args appear on the RHS in declaration order; the output swaps them.
    expected = 'D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]'
    self.assertEqual(expected, concat.to_rule_string())

  def test_to_rule_string_on_empty_should_not_throw_error(self):
    expected = "None[sem=''] -> None[sem=None] None[sem=None]"
    self.assertEqual(expected, gs.ConcatRule().to_rule_string())


class GrammarSchemaTest(parameterized.TestCase):
  """Tests based on simple ad-hoc grammars.

  Covers edge cases that are difficult to cover in GrammarSchemaForSCANTest.
  """

  def test_get_max_level_returns_zero_for_an_empty_schema(self):
    self.assertEqual(0, gs.GrammarSchema().get_max_level())

  @parameterized.named_parameters(
      ('PrimitiveMappings only - max_level is zero because primitives are at ' +
       'level zero', gs.GrammarSchema(primitives=[_create_primitive_mapping()]),
       0),
      ('Single level of FunctionRules (plus PrimitiveMappings)',
       _create_single_level_schema(), 1),
      ('Two levels of FunctionRules (plus PrimitiveMappings)',
       _create_two_level_schema(), 2),
      ('Two levels of FunctionRules plus a PassThroughRule at a level lower',
       _create_two_level_schema(
           level_0_category=L0_CATEGORY,
           level_1_category=L1_CATEGORY,
           pass_through_rules={
               1:
                   _create_pass_through_rule(
                       category=L1_CATEGORY, arg_category=L0_CATEGORY),
           }), 2),
      ('Two levels of FunctionRules plus a PassThroughRule at a level higher',
       _create_two_level_schema(
           level_2_category=L2_CATEGORY,
           pass_through_rules={
               3:
                   _create_pass_through_rule(
                       category=L3_CATEGORY, arg_category=L2_CATEGORY),
           }), 3),
      ('Two levels of FunctionRules plus a ConcatRule at a level lower',
       _create_two_level_schema(
           level_0_category=L0_CATEGORY,
           level_1_category=L1_CATEGORY,
           concat_rule_level=1,
           concat_rule=_create_concat_rule(
               category=L1_CATEGORY, arg_category=L0_CATEGORY)), 2),
      ('Two levels of FunctionRules plus a ConcatRule at a level higher',
       _create_two_level_schema(
           level_2_category=L2_CATEGORY,
           concat_rule_level=3,
           concat_rule=_create_concat_rule(
               category=L3_CATEGORY, arg_category=L2_CATEGORY)), 3),
  )
  def test_get_max_level_returns_the_largest_level_across_all_rule_types(
      self, schema, expected_value):
    self.assertEqual(expected_value, schema.get_max_level())

  @parameterized.named_parameters(
      ('Primitives only',
       gs.GrammarSchema(primitives=[
           _create_primitive_mapping(
               input_sequence=['walk'], category=L0_CATEGORY),
           _create_primitive_mapping(
               input_sequence=['turn'], category=L0_CATEGORY),
       ]), L0_CATEGORY),
      ('None if empty', gs.GrammarSchema(), None),
      ('None if non-unique',
       gs.GrammarSchema(primitives=[
           _create_primitive_mapping(
               input_sequence=['walk'], category=L0_CATEGORY),
           _create_primitive_mapping(
               input_sequence=['left'], category=L0_CATEGORY_2),
       ]), None),
  )
  def test_get_start_symbol(self, schema, expected_value):
    if expected_value is None:
      self.assertIsNone(schema.get_start_symbol())
    else:
      self.assertEqual(expected_value, schema.get_start_symbol())

  def test_get_all_categories_for_level_combines_arg_and_rule_categories(self):
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [_create_single_arg_function_rule(arg_category=L0_CATEGORY_2)],
        })
    self.assertCountEqual({L0_CATEGORY, L0_CATEGORY_2},
                          schema.get_all_categories_for_level(
                              0, inputs.GrammarOptions()))

  def test_get_args_for_level_returns_references_to_actual_schema_content(self):
    # Verify that if we modify the returned args, the contents of the
    # original GrammarSchema change.
    schema = gs.GrammarSchema(
        functions_by_level={
            1: [
                _create_two_arg_function_rule(
                    arg_category=L0_CATEGORY, arg2_category=L0_CATEGORY_2)
            ],
        })
    options = inputs.GrammarOptions(possible_categories_by_level=inputs
                                    .POSSIBLE_CATEGORIES_BY_LEVEL_8_PER_LEVEL)
    self.assertCountEqual({L0_CATEGORY, L0_CATEGORY_2},
                          schema.get_all_categories_for_level(0, options))
    for arg in schema.get_args_for_level(1):
      arg.category = 'U1'
    self.assertCountEqual({'U1'},
                          schema.get_all_categories_for_level(0, options))

  # The token-count tests compare as dicts: assertCountEqual on a mapping only
  # iterates (and so only compares) its keys, leaving the counts unchecked.
  def test_get_input_token_counts_handles_repeated_token_occurrences(self):
    schema = gs.GrammarSchema(
        primitives=[
            gs.PrimitiveMapping(input_sequence=['abc']),
            gs.PrimitiveMapping(input_sequence=['abc']),
        ],
        functions_by_level={
            1: [gs.FunctionRule(function_phrase=['abc'])],
            2: [gs.FunctionRule(function_phrase=['def', 'abc', 'def'])],
        })
    self.assertEqual({
        'abc': 4,
        'def': 2,
    }, dict(schema.get_input_token_usage_counts()))

  def test_get_output_token_counts_handles_repeated_token_occurrences(self):
    schema = gs.GrammarSchema(
        primitives=[
            gs.PrimitiveMapping(output_sequence=['abc']),
            gs.PrimitiveMapping(output_sequence=['abc']),
        ],
        functions_by_level={
            1: [gs.FunctionRule(output_sequence=['abc'])],
            2: [gs.FunctionRule(output_sequence=['def', 'abc', 'def'])],
        })
    self.assertEqual({
        'abc': 4,
        'def': 2,
    }, dict(schema.get_output_token_usage_counts()))


class GrammarSchemaForSCANTest(absltest.TestCase):
  """Tests based on the contents of the original SCAN grammar.

  This verifies that the relevant methods' behavior match what we would expect
  in a situation familiar to the main users of the GrammarSchema class.
  """

  def setUp(self):
    super().setUp()
    # All of the tests in this class act on the same example schema below, which
    # is equivalent to the SCAN grammar from scan_finite_nye_standardized.fcfg.
    self.schema = (
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())

  def test_get_args_for_level(self):
    # Level 0 consists only of primitives, so there are no args.
    self.assertEqual([], self.schema.get_args_for_level(0))
    # Level 1 contains a PassThroughRule and a ConcatRule.
    self.assertEqual([
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
    ], self.schema.get_args_for_level(1))
    # Level 2 contains two FunctionRules as well as a PassThroughRule.
    self.assertEqual([
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
        gs.FunctionArg(variable='?x1', category=L0_CATEGORY),
        gs.FunctionArg(variable='?x2', category=L0_CATEGORY_2),
        gs.FunctionArg(variable='?x1', category=L1_CATEGORY),
    ], self.schema.get_args_for_level(2))

  def test_get_all_categories_for_level(self):
    options = inputs.GrammarOptions()
    self.assertCountEqual({L0_CATEGORY, L0_CATEGORY_2},
                          self.schema.get_all_categories_for_level(0, options))
    self.assertCountEqual({L1_CATEGORY},
                          self.schema.get_all_categories_for_level(1, options))
    self.assertCountEqual({L2_CATEGORY},
                          self.schema.get_all_categories_for_level(2, options))

  # The token-usage tests compare as dicts: assertCountEqual on a mapping only
  # iterates (and so only compares) its keys, leaving the counts unchecked.
  def test_get_input_token_usage_counts(self):
    self.assertEqual(
        {
            'after': 1,
            'and': 1,
            'around': 1,
            'jump': 1,
            'left': 1,
            'look': 1,
            'opposite': 1,
            'right': 1,
            'run': 1,
            'thrice': 1,
            'turn': 1,
            'twice': 1,
            'walk': 1
        }, dict(self.schema.get_input_token_usage_counts()))

  def test_get_output_token_usage_counts(self):
    self.assertEqual(
        {
            '': 1,
            'JUMP': 1,
            'LTURN': 1,
            'LOOK': 1,
            'RTURN': 1,
            'RUN': 1,
            'WALK': 1
        }, dict(self.schema.get_output_token_usage_counts()))

  def test_to_grammar_string(self):
    actual_grammar_string = self.schema.to_grammar_string()
    expected_grammar_string = test_utils.strip_blank_and_comment_lines(
        grammar_loader.load_standard_grammar_string(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED))
    self.assertEqual(expected_grammar_string, actual_grammar_string)

  def test_sample_rules(self):
    rng = np.random.RandomState(42)
    sampling_options = inputs.SamplingOptions(
        num_rules=10, num_rules_min=8, num_rules_max=12, num_rules_stddev=6.0)
    self.schema.sample_rules(sampling_options, rng)
    # Materialize once so both counts are taken over the same rule list, even
    # if get_all_rules() returns a one-shot iterator.
    all_rules = list(self.schema.get_all_rules())
    num_passthrough_rules = sum(
        1 for rule in all_rules if isinstance(rule, gs.PassThroughRule))
    num_other_rules = len(all_rules) - num_passthrough_rules
    with self.subTest('should_keep_all_passthrough_rules'):
      self.assertEqual(4, num_passthrough_rules)
    with self.subTest('should_honor_num_rules_max'):
      self.assertGreaterEqual(sampling_options.num_rules_max, num_other_rules)
    with self.subTest('should_honor_num_rules_min'):
      self.assertLessEqual(sampling_options.num_rules_min, num_other_rules)


class GrammarSchemaValidateValidSchemaTest(parameterized.TestCase):
  """Checks that a range of well-formed schemas pass validation."""

  @parameterized.named_parameters(
      ('Single level', _create_single_level_schema()),
      ('Multi-level', _create_two_level_schema()),
      ('PassThroughRule',
       _create_two_level_schema(
           level_0_category=L0_CATEGORY,
           level_1_category=L1_CATEGORY,
           level_2_category=L2_CATEGORY,
           pass_through_rules={
               1: _create_pass_through_rule(
                   category=L1_CATEGORY, arg_category=L0_CATEGORY),
               2: _create_pass_through_rule(
                   category=L2_CATEGORY, arg_category=L1_CATEGORY),
           })),
      ('ConcatRule at level 1',
       _create_two_level_schema(
           level_0_category=L0_CATEGORY,
           level_1_category=L1_CATEGORY,
           concat_rule_level=1,
           concat_rule=_create_concat_rule(
               category=L1_CATEGORY, arg_category=L0_CATEGORY))),
      ('ConcatRule at level 2',
       _create_two_level_schema(
           level_1_category=L1_CATEGORY,
           level_2_category=L2_CATEGORY,
           concat_rule_level=2,
           concat_rule=_create_concat_rule(
               category=L2_CATEGORY, arg_category=L1_CATEGORY))),
      ('Primitives only',
       gs.GrammarSchema(primitives=[
           _create_primitive_mapping(
               input_sequence=['a'], output_sequence=['A'],
               category=L0_CATEGORY),
       ])),
      ('Primitives and ConcatRule only',
       gs.GrammarSchema(
           primitives=[
               _create_primitive_mapping(
                   input_sequence=['a'], output_sequence=['A'],
                   category=L0_CATEGORY),
           ],
           concat_rule_level=1,
           concat_rule=_create_concat_rule(
               category=L1_CATEGORY, arg_category=L0_CATEGORY))),
  )
  def test_validate(self, schema):
    default_options = inputs.GrammarOptions()
    # validate() must not raise, and is_valid() must agree.
    schema.validate(default_options)
    self.assertTrue(schema.is_valid(default_options))


class GrammarSchemaValidateInvalidSchemaTest(absltest.TestCase):
  """Checks that `GrammarSchema.validate` rejects malformed schemas.

  Each test builds a schema containing one specific defect and asserts that
  validation raises a ValueError whose message matches the expected regex.
  """

  def _assert_validation_fails(self, schema, expected_regex, options=None):
    """Asserts that validating `schema` raises a matching ValueError.

    The options are constructed before entering the `assertRaisesRegex`
    context so that only `schema.validate` can satisfy the expectation.

    Args:
      schema: The GrammarSchema under test.
      expected_regex: Regex that must match the raised ValueError's message.
      options: GrammarOptions to validate against. If None, a default
        `inputs.GrammarOptions()` is used.
    """
    if options is None:
      options = inputs.GrammarOptions()
    with self.assertRaisesRegex(ValueError, expected_regex):
      schema.validate(options)

  def test_arg_level_equal_to_function_level(self):
    """A level-1 function may not take an arg of a level-1 category."""
    schema = _create_two_level_schema(level_1_category=L1_CATEGORY)
    schema.functions_by_level[1].append(
        _create_single_arg_function_rule(
            function_phrase=['f1'], category='D1', arg_category=L1_CATEGORY))
    # The extra category 'D1' is presumably only available in the larger
    # 8-per-level category set, hence the non-default options.
    options = inputs.GrammarOptions(
        possible_categories_by_level=(
            inputs.POSSIBLE_CATEGORIES_BY_LEVEL_8_PER_LEVEL))
    self._assert_validation_fails(
        schema, 'Arg level must be smaller than function level', options)

  def test_arg_level_higher_than_function_level(self):
    """A level-1 function may not take an arg of a level-2 category."""
    schema = gs.GrammarSchema(
        primitives=[
            gs.PrimitiveMapping(
                input_sequence=['a'],
                output_sequence=['A'],
                category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L2_CATEGORY),
            ],
        })
    self._assert_validation_fails(
        schema, 'Arg level must be smaller than function level')

  def test_arg_lacking_category(self):
    """A function arg whose category is None is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].args[0].category = None
    self._assert_validation_fails(schema, 'Arg category not specified')

  def test_arg_lacking_variable(self):
    """A function arg whose variable is None is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].args[0].variable = None
    self._assert_validation_fails(schema, 'Arg variable not specified')

  def test_arg_list_empty(self):
    """A function with an empty arg list is rejected."""
    schema = _create_single_level_schema()
    function_rule = schema.functions_by_level[1][0]
    function_rule.args = []
    function_rule.output_sequence = []
    self._assert_validation_fails(schema, 'Arg list empty')

  def test_arg_missing_from_output_sequence(self):
    """Every arg variable must appear in the function's output sequence."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(),
        ],
        functions_by_level={
            1: [
                _create_two_arg_function_rule(
                    variable1='?x1', variable2='?x2', output_sequence=['?x1'])
            ],
        })
    self._assert_validation_fails(
        schema, 'Vars in output do not match those in rhs')

  def test_category_level_mismatch(self):
    """A function listed at level 1 may not produce a level-2 category."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    category=L2_CATEGORY, arg_category=L0_CATEGORY),
            ],
        })
    self._assert_validation_fails(schema, 'Category level mismatch')

  def test_category_consumed_but_not_produced(self):
    """A rule may not consume a category that no rule produces."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY),
                # Consumes L0_CATEGORY_2, which no primitive produces.
                _create_single_arg_function_rule(
                    function_phrase=['g'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY_2),
            ],
        })
    self._assert_validation_fails(
        schema, 'Category consumed but not produced by any rule')

  def test_category_not_consumed(self):
    """A produced category that no rule consumes is rejected."""
    schema = _create_single_level_schema(level_0_category=L0_CATEGORY)
    schema.primitives.append(
        gs.PrimitiveMapping(
            input_sequence=['b'], output_sequence=['B'],
            category=L0_CATEGORY_2))
    self._assert_validation_fails(schema, 'Category not consumed by any rule')

  def test_concat_rule_arg_of_too_low_level(self):
    """A level-2 ConcatRule taking a level-0 arg is rejected only if checked.

    The check is controlled by the `validate_concat_rule_level` option.
    """
    schema = _create_two_level_schema(
        level_0_category=L0_CATEGORY,
        level_2_category=L2_CATEGORY,
        concat_rule_level=2,
        concat_rule=_create_concat_rule(
            category=L2_CATEGORY, arg_category=L0_CATEGORY))
    self._assert_validation_fails(
        schema, 'ConcatRule arg must be from exactly one level below',
        inputs.GrammarOptions(validate_concat_rule_level=True))
    try:
      schema.validate(inputs.GrammarOptions(validate_concat_rule_level=False))
    except ValueError:
      self.fail(f'Levels of ConcatRule args should not be validated: '
                f'{traceback.format_exc()}')

  def test_concat_rule_arg_of_too_high_level(self):
    """A level-1 ConcatRule may not take an arg of a level-2 category."""
    schema = _create_two_level_schema(
        level_1_category=L1_CATEGORY,
        level_2_category=L2_CATEGORY,
        concat_rule_level=1,
        concat_rule=_create_concat_rule(
            category=L1_CATEGORY, arg_category=L2_CATEGORY))
    self._assert_validation_fails(
        schema, 'Arg level must be smaller than function level')

  def test_concat_rule_category_level_mismatch(self):
    """A ConcatRule declared at level 2 may not produce a level-1 category."""
    schema = _create_two_level_schema(
        level_0_category=L0_CATEGORY,
        level_1_category=L1_CATEGORY,
        concat_rule_level=2,
        concat_rule=_create_concat_rule(
            category=L1_CATEGORY, arg_category=L0_CATEGORY))
    self._assert_validation_fails(schema, 'Category level mismatch')

  def test_concat_rule_level_missing_but_rule_specified(self):
    """A ConcatRule without a concat_rule_level is rejected."""
    schema = _create_two_level_schema(
        level_0_category=L0_CATEGORY,
        level_1_category=L1_CATEGORY,
        concat_rule_level=None,
        concat_rule=_create_concat_rule(
            category=L1_CATEGORY, arg_category=L0_CATEGORY))
    self._assert_validation_fails(
        schema, 'Concat rule level missing but rule specified')

  def test_concat_rule_level_specified_but_rule_missing(self):
    """A concat_rule_level without a ConcatRule is rejected."""
    schema = _create_two_level_schema(concat_rule_level=1, concat_rule=None)
    self._assert_validation_fails(
        schema, 'Concat rule level specified but rule missing')

  def test_duplicate_arg(self):
    """A function may not list the same arg twice."""
    schema = _create_single_level_schema()
    function_rule = schema.functions_by_level[1][0]
    function_rule.args.append(function_rule.args[0])
    self._assert_validation_fails(
        schema, 'Duplicate variable in function args')

  def test_duplicate_function_in_different_level(self):
    """The same function phrase may not appear at two different levels."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY),
            ],
            2: [
                _create_two_arg_function_rule(
                    function_phrase=['f'],
                    category=L2_CATEGORY,
                    arg_category=L1_CATEGORY),
            ],
        })
    self._assert_validation_fails(schema, 'Duplicate function')

  def test_duplicate_function_in_same_level(self):
    """The same function phrase may not appear twice at one level."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY),
                _create_two_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY),
            ],
        })
    self._assert_validation_fails(schema, 'Duplicate function')

  def test_empty(self):
    """An empty schema is rejected, and is_valid() reports False."""
    schema = gs.GrammarSchema()
    options = inputs.GrammarOptions()
    self._assert_validation_fails(schema, 'Lacking primitives', options)
    # Testing just once that is_valid() == false when validate() fails.
    # Not necessary to check this in every test.
    self.assertFalse(schema.is_valid(options))

  def test_empty_function_phrase(self):
    """A function with an empty function phrase is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].function_phrase = []
    self._assert_validation_fails(schema, 'Empty function phrase')

  def test_function_lacking_category(self):
    """A function whose category is None is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].category = None
    self._assert_validation_fails(schema, 'Rule lacking category')

  def test_invalid_arg_in_output_sequence(self):
    """The output sequence may not reference a variable absent from args."""
    schema = gs.GrammarSchema(
        primitives=[
            _create_primitive_mapping(category=L0_CATEGORY),
        ],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    variable='?x1', output_sequence=['?x2', '?x1']),
            ],
        })
    self._assert_validation_fails(
        schema, 'Vars in output do not match those in rhs')

  def test_lacking_primitives(self):
    """A schema with rules but no primitives is rejected."""
    schema = _create_single_level_schema()
    schema.primitives = []
    self._assert_validation_fails(schema, 'Lacking primitives')

  def test_num_args_mismatch(self):
    """num_args must agree with the length of the arg list."""
    schema = _create_single_level_schema()
    function_rule = schema.functions_by_level[1][0]
    function_rule.num_args += 1
    self._assert_validation_fails(schema, 'Mismatch in number of args')

  def test_num_args_not_specified(self):
    """A function whose num_args is None is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].num_args = None
    self._assert_validation_fails(schema, 'Num_args not specified')

  def test_num_postfix_args_not_specified(self):
    """A function whose num_postfix_args is None is rejected."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].num_postfix_args = None
    self._assert_validation_fails(schema, 'Num_postfix_args not specified')

  def test_num_postfix_args_too_large(self):
    """num_postfix_args may not exceed num_args."""
    schema = _create_single_level_schema()
    function_rule = schema.functions_by_level[1][0]
    function_rule.num_postfix_args = function_rule.num_args + 1
    self._assert_validation_fails(
        schema, 'Num_postfix_args cannot exceed num_args')

  def test_num_postfix_args_too_small(self):
    """num_postfix_args may not be negative."""
    schema = _create_single_level_schema()
    schema.functions_by_level[1][0].num_postfix_args = -1
    self._assert_validation_fails(
        schema, 'Num_postfix_args cannot be negative')

  def test_pass_through_rule_arg_of_too_high_level(self):
    """A level-1 PassThroughRule may not take an arg of a level-2 category."""
    schema = _create_two_level_schema(
        level_1_category=L1_CATEGORY,
        level_2_category=L2_CATEGORY,
        pass_through_rules={
            1:
                _create_pass_through_rule(
                    category=L1_CATEGORY, arg_category=L2_CATEGORY),
        })
    self._assert_validation_fails(
        schema, 'Arg level must be smaller than function level')

  def test_pass_through_rule_arg_of_too_low_level(self):
    """A level-2 PassThroughRule may not take an arg of a level-0 category."""
    schema = _create_two_level_schema(
        level_0_category=L0_CATEGORY,
        level_2_category=L2_CATEGORY,
        pass_through_rules={
            2:
                _create_pass_through_rule(
                    category=L2_CATEGORY, arg_category=L0_CATEGORY),
        })
    self._assert_validation_fails(
        schema, 'PassThroughRule arg must be from exactly one level below')

  def test_primitive_input_sequence_empty(self):
    """A primitive with an empty input sequence is rejected."""
    schema = _create_single_level_schema(level_0_category=L0_CATEGORY)
    schema.primitives.append(
        gs.PrimitiveMapping(
            input_sequence=[], output_sequence=['B'], category=L0_CATEGORY))
    self._assert_validation_fails(schema, 'Input sequence empty')

  def test_skipped_level(self):
    """Rules at level 2 with no rules at level 1 are rejected."""
    schema = gs.GrammarSchema(
        primitives=[_create_primitive_mapping(category=L0_CATEGORY)],
        functions_by_level={
            2: [_create_single_arg_function_rule(arg_category=L0_CATEGORY)],
        })
    self._assert_validation_fails(schema, 'No rules found at level: 1')

  def test_too_many_top_level_categories(self):
    """Two distinct categories at the top level leave no start symbol."""
    schema = gs.GrammarSchema(
        primitives=[_create_primitive_mapping(category=L0_CATEGORY)],
        functions_by_level={
            1: [
                _create_single_arg_function_rule(
                    function_phrase=['f'],
                    category=L1_CATEGORY,
                    arg_category=L0_CATEGORY),
                _create_two_arg_function_rule(
                    function_phrase=['g'],
                    category='D2',
                    arg_category=L0_CATEGORY),
            ],
        })
    # The extra category 'D2' is presumably only available in the larger
    # 8-per-level category set, hence the non-default options.
    options = inputs.GrammarOptions(
        possible_categories_by_level=(
            inputs.POSSIBLE_CATEGORIES_BY_LEVEL_8_PER_LEVEL))
    self._assert_validation_fails(schema, 'Start symbol not defined', options)


if __name__ == '__main__':
  # Entry point when this file is executed directly: hand off to absltest.
  absltest.main()
