# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import rule_conversion


class InterpretationRuleStringFromProductionTest(parameterized.TestCase):
  """Tests for turning productions into interpretation-rule strings."""

  def test_interpretation_rule_input_tokens_from_production(self):
    """Checks the input token list produced for a two-variable production."""
    parsed = nltk_utils.production_from_production_string(
        "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]")
    input_tokens = (
        rule_conversion.interpretation_rule_input_tokens_from_production(
            parsed))
    self.assertEqual(['x1', 'and', 'x2'], input_tokens)

  @parameterized.named_parameters(
      ('basic', "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]", 'x1 and x2'),
      ('pass_through_rule', 'A[sem=?x1] -> B[sem=?x1]', 'x1'),
  )
  def test_interpretation_rule_input_string_from_production(
      self, production_string, expected):
    """Checks the input string produced for each parameterized production."""
    parsed = nltk_utils.production_from_production_string(production_string)
    input_string = (
        rule_conversion.interpretation_rule_input_string_from_production(
            parsed))
    self.assertEqual(expected, input_string)

  @parameterized.named_parameters(
      (
          'PrimitiveMapping',
          "A[sem='WALK'] -> 'walk'",
          '[walk] = WALK',
      ),
      (
          'PrimitiveMappingEmptyString',
          "A[sem=''] -> 'walk'",
          '[walk] = EMPTY_STRING',
      ),
      (
          'FunctionRule',
          "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          '[x1 and x2] = [x1] [x2]',
      ),
      (
          'FunctionRuleTerminalInOutput',
          "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          '[x1 and x2] = [x1] THEN [x2]',
      ),
      (
          'PassThroughRule',
          'A[sem=?x1] -> B[sem=?x1]',
          '[x1] = [x1]',
      ),
      (
          'ConcatRule',
          'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]',
          '[x1 x2] = [x1] [x2]',
      ),
  )
  def test_to_interpretation_rule_string(self, production_string, expected):
    """Checks the full interpretation-rule string for each production."""
    parsed = nltk_utils.production_from_production_string(production_string)
    rule_string = rule_conversion._interpretation_rule_string_from_production(
        parsed)
    self.assertEqual(rule_string, expected)


class UtilityFunctionsTest(parameterized.TestCase):
  """Tests for the token- and rule-string helpers in rule_conversion."""

  @parameterized.named_parameters(
      (
          'no_consecutive_nonterminals',
          ['a', '?x1', 'b', '?x2'],
          ['a', '?x1', 'b', '?x2'],
      ),
      ('no_variable', ['a', 'b', 'b'], ['a b b']),
      ('one_variable', ['a', 'b', '?x1', 'c', 'd'], ['a b', '?x1', 'c d']),
      ('starts_with_variable', ['?x1', 'a', 'b'], ['?x1', 'a b']),
      ('ends_with_variable', ['a', 'b', '?x1'], ['a b', '?x1']),
      ('with_empty_strings', ['', 'b', ''], ['b']),
  )
  def test_merge_consecutive_terminals(self, tokens, expected):
    """Checks the merged token list for each parameterized input."""
    self.assertEqual(
        rule_conversion._merge_consecutive_terminals(tokens), expected)

  @parameterized.named_parameters(
      (
          'no_variable',
          ['a', 'b'],
          ['a', 'b'],
      ),
      (
          'one_variable_no_repeat',
          ['a', '?x1', 'b'],
          ['a', '?x1', 'b'],
      ),
      (
          'one_variable_repeated_twice',
          ['a', '?x1', '?x1', 'b'],
          ['a', '?x1 repeated twice', 'b'],
      ),
      (
          'one_variable_repeated_thrice',
          ['a', '?x1', '?x1', '?x1', 'b'],
          ['a', '?x1 repeated thrice', 'b'],
      ),
      (
          'one_variable_repeated_4_times',
          ['a', '?x1', '?x1', '?x1', '?x1', 'b'],
          ['a', '?x1 repeated 4 times', 'b'],
      ),
      (
          'two_variables_no_repeat',
          ['a', '?x1', '?x2', 'b'],
          ['a', '?x1', '?x2', 'b'],
      ),
      (
          'two_variables_first_repeated',
          ['a', '?x1', '?x1', '?x2', 'b'],
          ['a', '?x1 repeated twice', '?x2', 'b'],
      ),
      (
          'two_variables_second_repeated',
          ['a', '?x1', '?x2', '?x2', 'b'],
          ['a', '?x1', '?x2 repeated twice', 'b'],
      ),
      (
          'two_variables_both_repeated',
          ['a', '?x1', '?x1', '?x2', '?x2', 'b'],
          ['a', '?x1 repeated twice', '?x2 repeated twice', 'b'],
      ),
  )
  def test_process_repeated_variables(self, tokens, expected):
    """Checks how runs of the same variable token are rewritten."""
    self.assertEqual(
        rule_conversion._process_repeated_variables(tokens), expected)

  def test_natural_language_non_rule_request(self):
    """Checks the question wording wrapped around a non-rule request."""
    rephrased = rule_conversion.natural_language_non_rule_request(
        'original request')
    self.assertEqual(
        rephrased, 'What is the interpretation of "original request"?')

  @parameterized.named_parameters(
      ('no_variable', '[jump] = JUMP', 'jump'),
      ('variable_and_function_token', '[x1 twice] = [x1] [x1]', 'x1 twice'),
      ('multiple_variables', '[x1 and x2] = [x1] [x2]', 'x1 and x2'),
      (
          'multiple_variables_and_function_tokens',
          '[x1 and x2 twice] = [x1] [x2]',
          'x1 and x2 twice',
      ),
  )
  def test_input_phrase_from_rule(self, rule, expected):
    """Checks the input phrase extracted from each interpretation rule."""
    self.assertEqual(rule_conversion._input_phrase_from_rule(rule), expected)

  def test_input_phrase_from_rule_should_raise_error_if_incorrect_format(self):
    """A feature-grammar production string must be rejected as a rule."""
    malformed_rule = "C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]"
    with self.assertRaisesRegex(ValueError, 'Incorrect rule format'):
      rule_conversion._input_phrase_from_rule(malformed_rule)

  @parameterized.named_parameters(
      ('no_variable', '[jump] = JUMP', '[_] = JUMP'),
      (
          'variable_and_function_token',
          '[x1 twice] = [x1] [x1]',
          '[x1 _] = [x1] [x1]',
      ),
      (
          'multiple_variables',
          '[x1 and x2] = [x1] [x2]',
          '[x1 _ x2] = [x1] [x2]',
      ),
      (
          'multiple_variables_and_function_tokens',
          '[x1 and x2 twice] = [x1] [x2]',
          '[x1 _ x2 _] = [x1] [x2]',
      ),
  )
  def test_rule_pattern_from_rule(self, rule, expected):
    """Checks the rule pattern derived from each interpretation rule."""
    self.assertEqual(rule_conversion.rule_pattern_from_rule(rule), expected)


class NaturalLanguageStringFromProductionTest(parameterized.TestCase):
  """Tests for rendering productions as natural-language sentences."""

  def _assert_natural_language_string(self, rule_string, expected):
    """Parses rule_string and compares its natural-language rendering."""
    parsed = nltk_utils.production_from_production_string(rule_string)
    sentence = rule_conversion._natural_language_string_from_production(parsed)
    self.assertEqual(sentence, expected)

  @parameterized.named_parameters(
      (
          'PrimitiveMapping',
          "A[sem='WALK'] -> 'walk'",
          'The interpretation of "walk" is "WALK".',
      ),
      (
          'PrimitiveMappingEmptyString',
          "A[sem=''] -> 'walk'",
          'The interpretation of "walk" is the empty string.',
      ),
      (
          'FunctionRule',
          "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          ('The interpretation of a phrase x1 followed by "and" followed by a '
           'phrase x2 is the interpretation of x1 followed by the '
           'interpretation of x2.'),
      ),
      (
          'FunctionRuleTerminalInOutput',
          "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          ('The interpretation of a phrase x1 followed by "and" followed by a '
           'phrase x2 is the interpretation of x1 followed by "THEN" followed '
           'by the interpretation of x2.'),
      ),
      (
          'PassThroughRule',
          'A[sem=?x1] -> B[sem=?x1]',
          'The interpretation of a phrase x1 is the interpretation of x1.',
      ),
      (
          'ConcatRule',
          'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]',
          ('The interpretation of a phrase x1 followed by a phrase x2 is the '
           'interpretation of x1 followed by the interpretation of x2.'),
      ),
  )
  def test_to_natural_language_string_basic(self, rule_string, expected):
    """Checks the sentence produced for each basic production shape."""
    self._assert_natural_language_string(rule_string, expected)

  @parameterized.named_parameters(
      (
          'consecutive_terminal_rhs',
          "A[sem=(?x1)] -> B[sem=?x1] 'and' 'then' C[sem=?x2]",
          ('The interpretation of a phrase x1 followed by "and then" followed '
           'by a phrase x2 is the interpretation of x1.'),
      ),
      (
          'consecutive_terminal_lhs',
          'A[sem=(?x1+AND+THEN+?x2)] -> B[sem=?x1] C[sem=?x2]',
          ('The interpretation of a phrase x1 followed by a phrase x2 is the '
           'interpretation of x1 followed by "AND THEN" followed by the '
           'interpretation of x2.'),
      ),
      (
          'consecutive_variable_lhs',
          'A[sem=(?x1+?x1+?x2)] -> B[sem=?x1] C[sem=?x2]',
          ('The interpretation of a phrase x1 followed by a phrase x2 is the '
           'interpretation of x1 repeated twice followed by the interpretation '
           'of x2.'),
      ),
  )
  def test_to_natural_language_string_complex(self, rule_string, expected):
    """Checks sentences for consecutive terminals and repeated variables."""
    self._assert_natural_language_string(rule_string, expected)


def _get_feature_grammar():
  """Returns the small NLTK feature grammar used by the mapping tests.

  The grammar has start symbol C and contains one pass-through production,
  three two-variable productions and three primitive mappings (one of which
  has an empty-string semantics).
  """
  return nltk.grammar.FeatureGrammar.fromstring("""
  % start C
  C[sem=?x1] -> S[sem=?x1]
  S[sem=(?x2+THEN+?x1)] -> V[sem=?x1] 'after' V[sem=?x2]
  S[sem=(?x1+?x2)] -> V[sem=?x1] V[sem=?x2]
  A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]
  V[sem='WALK'] -> 'walk'
  V[sem='JUMP'] -> 'jump'
  V[sem=''] -> 'wait'
  """)


class RuleMappingFromGrammarTest(absltest.TestCase):
  """Tests for mapping rule strings to the grammar productions behind them."""

  def test_feature_grammar_production_rule_format(self):
    """Each production is expected under the key given by its own str()."""
    feature_grammar = _get_feature_grammar()
    actual_mapping = rule_conversion.rule_mapping_from_grammar(
        feature_grammar, enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION)

    expected_mapping = {}
    for production in feature_grammar.productions():
      expected_mapping[str(production)] = [production]
    self.assertEqual(actual_mapping, expected_mapping)

  def test_interpretation_rule_format(self):
    """Productions sharing an interpretation rule are grouped under one key."""
    to_production = nltk_utils.production_from_production_string
    actual_mapping = rule_conversion.rule_mapping_from_grammar(
        _get_feature_grammar(), enums.RuleFormat.INTERPRETATION_RULE)

    # The pass-through production "C[sem=?x1] -> S[sem=?x1]" has no entry in
    # the expected mapping.
    expected_mapping = {
        '[x1 after x2] = [x2] THEN [x1]': [
            to_production(
                "S[sem=(?x2+THEN+?x1)] -> V[sem=?x1] 'after' V[sem=?x2]"),
        ],
        '[x1 x2] = [x1] [x2]': [
            to_production('S[sem=(?x1+?x2)] -> V[sem=?x1] V[sem=?x2]'),
            to_production('A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]'),
        ],
        '[walk] = WALK': [to_production("V[sem='WALK'] -> 'walk'")],
        '[jump] = JUMP': [to_production("V[sem='JUMP'] -> 'jump'")],
        '[wait] = EMPTY_STRING': [to_production("V[sem=''] -> 'wait'")],
    }
    self.assertEqual(actual_mapping, expected_mapping)


class RuleFromProductionTest(parameterized.TestCase):
  """Tests for rule_from_production with the INTERPRETATION_RULE format."""

  @parameterized.named_parameters(
      (
          'PrimitiveMapping',
          "A[sem='WALK'] -> 'walk'",
          '[walk] = WALK',
      ),
      (
          'PrimitiveMappingEmptyString',
          "A[sem=''] -> 'walk'",
          '[walk] = EMPTY_STRING',
      ),
      (
          'FunctionRule',
          "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          '[x1 and x2] = [x1] [x2]',
      ),
      (
          'FunctionRuleTerminalInOutput',
          "A[sem=(?x1+THEN+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]",
          '[x1 and x2] = [x1] THEN [x2]',
      ),
      (
          'PassThroughRule',
          'A[sem=?x1] -> B[sem=?x1]',
          None,
      ),
      (
          'ConcatRule',
          'A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]',
          '[x1 x2] = [x1] [x2]',
      ),
  )
  def test_interpretation_rule_format(self, production_string, expected):
    """Checks the rule for each production; pass-through expects None."""
    parsed = nltk_utils.production_from_production_string(production_string)
    actual_rule = rule_conversion.rule_from_production(
        parsed, enums.RuleFormat.INTERPRETATION_RULE)

    self.assertEqual(actual_rule, expected)

  @parameterized.named_parameters(
      (
          'with_empty_string',
          'S[sem=(?x1+?x2)] -> V[sem=?x1] V[sem=?x2]',
          "V[sem=''] -> 'wait'",
          0,
          '[wait x1] = [x1]',
      ),
      (
          'with_empty_string_and_other_string',
          "S[sem=(?x2+THEN+?x1)] -> V[sem=?x1] 'after' V[sem=?x2]",
          "V[sem=''] -> 'wait'",
          2,
          '[x1 after wait] = THEN [x1]',
      ),
      (
          'only_empty_string',
          "S[sem=(?x1+)] -> V[sem=?x1] 'wait'",
          "V[sem=''] -> 'wait'",
          0,
          '[wait wait] = EMPTY_STRING',
      ),
      (
          'only_empty_string_with_leading_empty_string',
          "S[sem=(''+?x1)] -> 'wait' V[sem=?x1]",
          "V[sem=''] -> 'wait'",
          1,
          '[wait wait] = EMPTY_STRING',
      ),
  )
  def test_should_not_show_empty_string_token_unless_empty_sequence(
      self, parent_string, other_parent_string, index, expected):
    """Checks rules for productions composed with an empty-string mapping."""
    outer = nltk_utils.production_from_production_string(parent_string)
    inner = nltk_utils.production_from_production_string(other_parent_string)
    composed = production_composition.compose(outer, inner, index)
    actual_rule = rule_conversion.rule_from_production(
        composed, enums.RuleFormat.INTERPRETATION_RULE)

    self.assertEqual(actual_rule, expected)


# Hand control to absltest's test runner when this file is run as a script.
if __name__ == '__main__':
  absltest.main()
