# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import nltk_utils


class TokenExtractionFunctionsTest(parameterized.TestCase):
  """Tests for `extract_lhs_tokens` and `extract_rhs_tokens`."""

  def test_extract_lhs_tokens(self):
    # The tokens of the LHS `sem` feature are expected back in order.
    rule = nltk_utils.production_from_production_string(
        'A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] C[sem=?x2]')
    lhs_tokens = nltk_utils.extract_lhs_tokens(rule)
    self.assertEqual(lhs_tokens, ['?x1', 'THEN', '?x2'])

  @parameterized.named_parameters(
      ('unknown_sem_feature', 'A[sem={?x1}] -> B[sem=?x1]'),
      (
          'unknown_nested_sem_feature',
          'A[sem=(?x1+THEN+(?x1,NESTED))] -> B[sem=?x1]',
      ),
  )
  def test_extract_lhs_tokens_should_raise_value_error(self, production_string):
    # Parsing happens outside the context manager so that only the token
    # extraction itself is allowed to raise.
    rule = nltk_utils.production_from_production_string(production_string)
    expected_message = 'Failed to extract item from'
    with self.assertRaisesRegex(ValueError, expected_message):
      nltk_utils.extract_lhs_tokens(rule)

  def test_extract_rhs_tokens(self):
    # The RHS here interleaves two nonterminals with the terminal 'and'.
    rule = nltk_utils.production_from_production_string(
        "A[sem=?x1] -> B[sem=?x1] 'and' C[sem=?x2]")
    rhs_tokens = nltk_utils.extract_rhs_tokens(rule)
    self.assertEqual(rhs_tokens, ['?x1', 'and', '?x2'])

  def test_extract_rhs_tokens_should_raise_value_error(self):
    # The second RHS nonterminal carries the string 'WALK' as its semantics
    # instead of a variable.
    rule = nltk_utils.production_from_production_string(
        "A[sem=?x1] -> B[sem=?x1] C[sem='WALK']")
    expected_message = (
        'Nonterminal term on the RHS should have only nltk.Variable as '
        'semantics, but got')
    with self.assertRaisesRegex(ValueError, expected_message):
      nltk_utils.extract_rhs_tokens(rule)


class GrammarToStringTest(parameterized.TestCase):
  """Tests that `grammar_to_string` output can be parsed back by NLTK."""

  @parameterized.named_parameters(
      (
          'with_sem_features',
          ('% start C[sem=?x1]\n'
           'C[sem=?x1] -> S[sem=?x1]'),
      ),
      (
          'without_sem_features',
          ('% start  C\n'
           'C[sem=?x1] -> S[sem=?x1]\n'
           "S[sem=(?x1+?x2)] -> U[sem=?x1] 'and' U[sem=?x2]"),
      ),
      (
          'without_start_state',
          ("C -> S 'and' S\n"
           "S -> U 'twice'\n"
           'U -> JUMP'),
      ),
  )
  def test_grammar_to_string_roundtrip(self, original_grammar_string):
    parse = nltk.grammar.FeatureGrammar.fromstring

    # Parse, serialize with the function under test, then parse again.
    original = parse(original_grammar_string)
    serialized = nltk_utils.grammar_to_string(original)
    restored = parse(serialized)

    start_after = restored.start()
    start_before = original.start()
    with self.subTest('restored_start_state_should_match_original'):
      self.assertEqual(start_after, start_before)

    productions_before = original.productions()
    productions_after = restored.productions()
    # Production order is not compared, only the multiset of productions.
    with self.subTest('restored_productions_should_match_original'):
      self.assertCountEqual(productions_before, productions_after)


class ProductionFromProductionStringTest(parameterized.TestCase):
  """Tests converting production strings to productions and back to str."""

  @parameterized.named_parameters(
      ('PrimitiveMapping', "A[sem='WALK'] -> 'walk'"),
      (
          'PrimitiveMappingMultipleTokens',
          "A[sem=(LTURN, WALK)] -> 'walk' 'left'",
      ),
      (
          'FunctionRule',
          "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
      ),
      (
          'FunctionRuleTerminalInOutput',
          "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
      ),
      ('PassThroughRule', 'A[sem=?x1] -> B[sem=?x1]'),
      ('ConcatRule', 'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]'),
  )
  def test_from_standard_production_string(self, production_string):
    # A string that is already in the standard format must come back
    # unchanged when the parsed production is converted with `str()`.
    parsed = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(production_string, str(parsed))

  @parameterized.named_parameters(
      (
          'PrimitiveMapping_without_quotes',
          "A[sem=WALK] -> 'walk'",
          "A[sem='WALK'] -> 'walk'",
      ),
      (
          'PrimitiveMappingMultipleTokens_with_quotes_and_plus',
          "A[sem=('LTURN'+'WALK')] -> 'walk' 'left'",
          "A[sem=(LTURN, WALK)] -> 'walk' 'left'",
      ),
      (
          'FunctionRuleTerminalInOutput_with_quotes',
          "A[sem=(?x1+'THEN'+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
      ),
  )
  def test_from_non_standard_production_string(self, production_string,
                                               standardized_string):
    # These alternative spellings are accepted as input too, even though
    # `str()` on the resulting production emits the standardized format.
    parsed = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(standardized_string, str(parsed))


# Example production strings shared by all rule-type predicate tests, paired
# with the name used for the corresponding parameterized test case.
_RULE_TYPE_EXAMPLES = (
    ('PrimitiveMapping', "A[sem='WALK'] -> 'walk'"),
    ('PrimitiveMappingMultipleTokens',
     "A[sem=('LTURN'+'WALK')] -> 'walk' 'left'"),
    ('FunctionRule', "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]"),
    ('FunctionRuleTerminalInOutput',
     "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]"),
    ('PassThroughRule', 'A[sem=?x1] -> B[sem=?x1]'),
    ('ConcatRule', 'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]'),
)


def _rule_type_cases(*matching_names):
  """Builds named-parameter tuples covering every shared example.

  Args:
    *matching_names: Names of the examples for which the predicate under test
      is expected to return True; all other examples expect False.

  Returns:
    A tuple of (name, production_string, expected) tuples, in the order of
    `_RULE_TYPE_EXAMPLES`.
  """
  return tuple((name, production_string, name in matching_names)
               for name, production_string in _RULE_TYPE_EXAMPLES)


class IsRuleTypeTest(parameterized.TestCase):
  """Tests the `is_*` rule-type predicates on the same set of examples."""

  @parameterized.named_parameters(
      *_rule_type_cases('PrimitiveMapping', 'PrimitiveMappingMultipleTokens'))
  def test_is_primitive_mapping(self, production_string, expected):
    rule = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(nltk_utils.is_primitive_mapping(rule), expected)

  @parameterized.named_parameters(*_rule_type_cases('PassThroughRule'))
  def test_is_pass_through_rule(self, production_string, expected):
    rule = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(nltk_utils.is_pass_through_rule(rule), expected)

  @parameterized.named_parameters(
      *_rule_type_cases('FunctionRule', 'FunctionRuleTerminalInOutput'))
  def test_is_function_rule(self, production_string, expected):
    rule = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(nltk_utils.is_function_rule(rule), expected)

  @parameterized.named_parameters(*_rule_type_cases('ConcatRule'))
  def test_is_concat_rule(self, production_string, expected):
    rule = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(nltk_utils.is_concat_rule(rule), expected)


class VariableNamingConventionTest(parameterized.TestCase):
  """Tests the variable-prefix helpers and output pattern generation."""

  def test_add_variable_prefix(self):
    prefixed = nltk_utils.add_variable_prefix('x1')
    self.assertEqual('?x1', prefixed)

  @parameterized.named_parameters(
      ('removes_prefix_from_variable_token', '?x1', 'x1'),
      ('noop_on_non_variable_token', 'JUMP', 'JUMP'),
      ('noop_on_empty_string', '', ''),
  )
  def test_strip_variable_prefix(self, token, expected_result):
    stripped = nltk_utils.strip_variable_prefix(token)
    self.assertEqual(expected_result, stripped)

  @parameterized.named_parameters(
      ('variable_token', '?x1', True),
      ('non_variable_token', 'JUMP', False),
      ('empty_string', '', False),
  )
  def test_is_output_token_a_variable(self, token, expected_result):
    with self.subTest('returns_correct_result_initially'):
      is_variable = nltk_utils.is_output_token_a_variable(token)
      self.assertEqual(expected_result, is_variable)

    # Whatever the starting token, adding the prefix must yield a variable
    # and stripping it must yield a non-variable.
    with self.subTest('always_returns_true_after_adding_variable_prefix'):
      prefixed = nltk_utils.add_variable_prefix(token)
      self.assertTrue(nltk_utils.is_output_token_a_variable(prefixed))

    with self.subTest('always_returns_false_after_stripping_variable_prefix'):
      stripped = nltk_utils.strip_variable_prefix(token)
      self.assertFalse(nltk_utils.is_output_token_a_variable(stripped))

  @parameterized.named_parameters(
      ('only_terminals', 'A[sem=(WALK+JUMP+WALK)] -> B', 'o1 o2 o1'),
      ('only_variables', 'A[sem=(?x1+?x2+?x1)] -> B', 'x1 x2 x1'),
      (
          'terminals_and_variables',
          'A[sem=(?x3+LTURN+?x3+WALK+WALK+LTURN+?x1)] -> B',
          'x1 o1 x1 o2 o2 o1 x2',
      ),
  )
  def test_output_pattern_from_production(self, production_string,
                                          expected_output_pattern):
    # In the expected patterns, variables and terminals are each numbered by
    # order of first appearance (e.g. ?x3 -> x1, LTURN -> o1).
    rule = nltk_utils.production_from_production_string(production_string)
    pattern = nltk_utils.output_pattern_from_production(rule)
    self.assertEqual(expected_output_pattern, pattern)


# Run all tests in this module via absltest when executed as a script.
if __name__ == '__main__':
  absltest.main()
