# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import grammar_representation
from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import test_utils


class GrammarRepresentationTest(parameterized.TestCase):
  """Tests for the private conversion helpers in grammar_representation."""

  def test_category_from_nonterminal(self):
    """The category of a nonterminal is its bare symbol, e.g. 'A'."""
    rule = nltk_utils.production_from_production_string(
        "A[sem=(?x1+?x1)] -> B[sem=?x1] 'twice'")
    self.assertEqual(
        grammar_representation._category_from_nonterminal(rule.lhs()), 'A')

  def test_function_arg_from_nonterminal(self):
    """An RHS nonterminal converts to a FunctionArg (variable, category)."""
    rule = nltk_utils.production_from_production_string(
        "A[sem=(?x1+?x1)] -> B[sem=?x1] 'twice'")
    rhs_head = rule.rhs()[0]
    self.assertEqual(
        grammar_representation._function_arg_from_nonterminal(rhs_head),
        gs.FunctionArg(variable='?x1', category='B'))

  @parameterized.named_parameters(
      ('variable_concat',
       "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
       ['?x1', '?x2']),
      ('string', "A[sem='WALK'] -> 'walk'", ['WALK']),
      ('string_concat',
       "A[sem=('LTURN'+'WALK')] -> 'walk' 'left'",
       ['LTURN', 'WALK']),
      ('string_and_variable',
       "A[sem=('LTURN'+?x1)] -> B[sem=?x1] 'left'",
       ['LTURN', '?x1']),
  )
  def test_output_sequence_from_nonterminal(self, production_string, expected):
    """The LHS `sem` value is flattened into a list of output tokens."""
    lhs = nltk_utils.production_from_production_string(production_string).lhs()
    actual = grammar_representation._output_sequence_from_nonterminal(lhs)
    self.assertEqual(actual, expected)

  @parameterized.named_parameters(
      ('basic',
       "A[sem='WALK'] -> 'walk'",
       gs.PrimitiveMapping(
           input_sequence=['walk'], output_sequence=['WALK'], category='A')),
      ('string_concat',
       "A[sem=('LTURN'+'WALK')] -> 'walk' 'left'",
       gs.PrimitiveMapping(
           input_sequence=['walk', 'left'],
           output_sequence=['LTURN', 'WALK'],
           category='A')),
  )
  def test_primitive_mapping_from_production(self, production_string, expected):
    """Rules whose right-hand side is all terminals become PrimitiveMappings."""
    rule = nltk_utils.production_from_production_string(production_string)
    actual = grammar_representation._primitive_mapping_from_production(rule)
    self.assertEqual(actual, expected)

  def test_pass_through_rule_from_production(self):
    """`A -> B` sharing a single variable converts to a PassThroughRule."""
    rule = nltk_utils.production_from_production_string(
        'A[sem=?x1] -> B[sem=?x1]')
    expected = gs.PassThroughRule(
        category='A', arg=gs.FunctionArg(variable='?x1', category='B'))
    actual = grammar_representation._pass_through_rule_from_production(rule)
    self.assertEqual(actual, expected)

  def test_function_rule_from_production(self):
    """Arguments around a function phrase ('and') yield a FunctionRule."""
    rule = nltk_utils.production_from_production_string(
        "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]")
    expected = gs.FunctionRule(
        function_phrase=['and'],
        category='A',
        num_args=2,
        # C is the only argument that appears after 'and' in the rule above.
        num_postfix_args=1,
        args=[
            gs.FunctionArg(variable='?x1', category='B'),
            gs.FunctionArg(variable='?x2', category='C'),
        ],
        output_sequence=['?x1', '?x2'])
    actual = grammar_representation._function_rule_from_production(rule)
    self.assertEqual(actual, expected)

  def test_concat_rule_from_production(self):
    """Two adjacent nonterminals with no terminals become a ConcatRule."""
    rule = nltk_utils.production_from_production_string(
        'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]')
    expected = gs.ConcatRule(
        category='A',
        arg1=gs.FunctionArg(variable='?x1', category='B'),
        arg2=gs.FunctionArg(variable='?x2', category='C'),
        output_sequence=['?x1', '?x2'])
    actual = grammar_representation._concat_rule_from_production(rule)
    self.assertEqual(actual, expected)

  def test_level_by_category_from_feature_grammar(self):
    """Standardized SCAN categories get levels D=1, V=2, S=3, C=4."""
    scan_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    actual = grammar_representation._level_by_category_from_feature_grammar(
        scan_grammar)
    self.assertEqual(actual, {'D': 1, 'V': 2, 'S': 3, 'C': 4})


class GrammarSchemaFromFeatureGrammarTest(parameterized.TestCase):
  """Tests for grammar_representation.grammar_schema_from_feature_grammar."""

  def test_converts_scan_finite_nye_standardized(self):
    """The standardized SCAN grammar converts to the reference schema."""
    scan_grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    actual = grammar_representation.grammar_schema_from_feature_grammar(
        scan_grammar)
    self.assertEqual(
        actual,
        test_utils.get_grammar_schema_for_scan_finite_nye_standardized())

  @parameterized.named_parameters(
      ('pass_through_rule_different_variables', 'A[sem=?x1] -> B[sem=?x2]'),
      ('non_sem_features_rhs',
       "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2,non_sem=ignored]"),
      ('non_sem_features_lhs',
       "A[sem=(?x1+?x2),non_sem=ignored] -> B[sem=?x1] 'and' C[sem=?x2]"),
  )
  def test_raises_error_if_sem_feature_not_representable_as_grammar_schema(
      self, feature_grammar_string):
    """Rules whose features the schema cannot express raise ValueError."""
    grammar = nltk.grammar.FeatureGrammar.fromstring(feature_grammar_string)
    # NOTE(review): the expected message suggests the converter detects this
    # by comparing production strings after conversion; confirm in the impl.
    with self.assertRaisesRegex(
        ValueError, 'Grammars do not have the same production strings.'):
      grammar_representation.grammar_schema_from_feature_grammar(grammar)

  @parameterized.named_parameters(
      ('unsupported_function_rule',
       "A[sem=?x1] -> 'do' B[sem=?x1] 'carefully'"),
      ('unsupported_concat_rule',
       'A[sem=(?x1+?x2+?x3)] -> B[sem=?x1] C[sem=?x2] D[sem=?x3]'),
  )
  def test_raises_error_if_phrase_structure_not_representable_as_grammar_schema(
      self, feature_grammar_string):
    """Rule shapes with no matching schema rule type raise ValueError."""
    grammar = nltk.grammar.FeatureGrammar.fromstring(feature_grammar_string)
    with self.assertRaisesRegex(ValueError, 'Failed to convert production:'):
      grammar_representation.grammar_schema_from_feature_grammar(grammar)


if __name__ == '__main__':
  # Run this module's test cases through absl's test runner.
  absltest.main()
