# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import nltk
import numpy as np

from conceptual_learning.cscan import distractor_generation
from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition


def _get_feature_grammar(non_variable_token=None):
  """Builds the single-rule feature grammar `C -> D` used as test context.

  Args:
    non_variable_token: Optional token. When truthy, it is concatenated (as a
      quoted string) after the variable in the `sem` feature of the rule's
      left-hand side; otherwise the left-hand side `sem` is just the variable.

  Returns:
    An `nltk.grammar.FeatureGrammar` parsed from the resulting rule string.
  """
  if not non_variable_token:
    return nltk.grammar.FeatureGrammar.fromstring("C[sem=?x] -> D[sem=?x]")
  return nltk.grammar.FeatureGrammar.fromstring(
      f"C[sem=(?x+'{non_variable_token}')] -> D[sem=?x]")


class ProductionEditTest(parameterized.TestCase):
  """Exercises the single-production edit functions of distractor_generation.

  Every edit is invoked as `edit(production, grammar, rng)`. For each edit
  there is a parameterized test that compares the edited production with the
  acceptable outcomes, and a companion test that expects
  `_FailedToApplyEditError` for an input the edit cannot be applied to.
  """

  def setUp(self):
    super().setUp()
    # Fixed seed for the random state handed to every edit.
    self.rng = np.random.RandomState(42)

  def _assert_edit_yields_one_of(self, edit, production_string, grammar,
                                 candidate_strings):
    """Applies `edit` and asserts the outcome is one of the candidates.

    Args:
      edit: Callable invoked as `edit(production, grammar, rng)`.
      production_string: String form of the production to be edited.
      grammar: Grammar passed through to `edit`.
      candidate_strings: String forms of the acceptable edited productions.
    """
    source = nltk_utils.production_from_production_string(production_string)
    outcome = edit(source, grammar, self.rng)
    candidates = list(
        map(nltk_utils.production_from_production_string, candidate_strings))
    self.assertIn(outcome, candidates)

  def _assert_edit_fails(self, edit, production_string, grammar):
    """Applies `edit` and asserts it raises `_FailedToApplyEditError`.

    Args:
      edit: Callable invoked as `edit(production, grammar, rng)`.
      production_string: String form of the production to be edited.
      grammar: Grammar passed through to `edit`.
    """
    source = nltk_utils.production_from_production_string(production_string)
    with self.assertRaises(distractor_generation._FailedToApplyEditError):
      edit(source, grammar, self.rng)

  @parameterized.named_parameters(
      ("no_variable", "A[sem=(WALK+JUMP)] -> B", "A[sem=(JUMP+WALK)] -> B"),
      ("some_variables", "A[sem=(WALK+?x1)] -> B", "A[sem=(?x1+WALK)] -> B"),
      ("all_variables", "A[sem=(?x2+?x1)] -> B", "A[sem=(?x1+?x2)] -> B"),
  )
  def test_swap(self, production_string, expected_production_string):
    source = nltk_utils.production_from_production_string(production_string)
    actual = distractor_generation._swap(source, _get_feature_grammar(),
                                         self.rng)
    wanted = nltk_utils.production_from_production_string(
        expected_production_string)
    self.assertEqual(actual, wanted)

  def test_swap_should_raise_error_if_validation_fails(self):
    # The production is valid for _swap only if there are at least two different
    # tokens, so a production repeating a single token is used here.
    self._assert_edit_fails(
        distractor_generation._swap,
        "A[sem=(WALK+WALK)] -> B",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_no_variable",
          "A[sem=WALK] -> B",
          ["A[sem=(NEW+WALK)] -> B", "A[sem=(WALK+NEW)] -> B"],
      ),
      (
          "lhs_has_variable",
          "A[sem=?x1] -> B",
          ["A[sem=(?x1+NEW)] -> B", "A[sem=(NEW+?x1)] -> B"],
      ),
  )
  def test_add_non_variable(self, production_string,
                            possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._add_non_variable,
        production_string,
        _get_feature_grammar(non_variable_token="NEW"),
        possible_production_strings)

  def test_add_non_variable_should_raise_error_if_validation_fails(self):
    # The grammar is valid for _add_non_variable if it contains at least one
    # non-variable token; this grammar contains none.
    variable_only_grammar = nltk.grammar.FeatureGrammar.fromstring(
        "C[sem=?x] -> D[sem=?x]")
    self._assert_edit_fails(
        distractor_generation._add_non_variable,
        "A[sem=WALK] -> B",
        variable_only_grammar)

  @parameterized.named_parameters(
      (
          "lhs_no_variable",
          "A[sem=WALK] -> B[sem=?x1]",
          [
              "A[sem=(?x1+WALK)] -> B[sem=?x1]",
              "A[sem=(WALK+?x1)] -> B[sem=?x1]",
          ],
      ),
      (
          "lhs_has_variable",
          "A[sem=?x1] -> B[sem=?x1]",
          ["A[sem=(?x1+?x1)] -> B[sem=?x1]"],
      ),
  )
  def test_add_variable(self, production_string, possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._add_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_add_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _add_variable if its RHS contains at least one
    # variable.
    self._assert_edit_fails(
        distractor_generation._add_variable,
        "A[sem=WALK] -> 'walk'",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_no_variable",
          "A[sem=(WALK+JUMP)] -> B",
          ["A[sem=WALK] -> B", "A[sem=JUMP] -> B"],
      ),
      (
          "lhs_has_variable",
          "A[sem=(?x1+WALK)] -> B",
          ["A[sem=?x1] -> B"],
      ),
  )
  def test_remove_non_variable(self, production_string,
                               possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._remove_non_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_remove_non_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _remove_non_variable if its LHS contains at
    # least one non-variable.
    self._assert_edit_fails(
        distractor_generation._remove_non_variable,
        "A[sem=?x1] -> B",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_two_variables",
          "A[sem=(?x1+?x2)] -> B",
          ["A[sem=?x1] -> B", "A[sem=?x2] -> B"],
      ),
      (
          "lhs_one_variable",
          "A[sem=(?x1+WALK)] -> B",
          ["A[sem=WALK] -> B"],
      ),
  )
  def test_remove_variable(self, production_string,
                           possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._remove_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_remove_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _remove_variable if its LHS contains at least
    # one variable.
    self._assert_edit_fails(
        distractor_generation._remove_variable,
        "A[sem=WALK] -> 'walk'",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_no_variable",
          "A[sem=(WALK+JUMP)] -> B",
          ["A[sem=(WALK+NEW)] -> B", "A[sem=(NEW+JUMP)] -> B"],
      ),
      (
          "lhs_has_variable",
          "A[sem=(?x1+WALK)] -> B",
          ["A[sem=(?x1+NEW)] -> B"],
      ),
  )
  def test_replace_non_variable(self, production_string,
                                possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._replace_non_variable,
        production_string,
        _get_feature_grammar(non_variable_token="NEW"),
        possible_production_strings)

  def test_replace_non_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _replace_non_variable if its LHS contains at
    # least one non-variable.
    self._assert_edit_fails(
        distractor_generation._replace_non_variable,
        "A[sem=?x1] -> B",
        _get_feature_grammar(non_variable_token="NEW"))

  @parameterized.named_parameters(
      (
          "lhs_two_variables",
          "A[sem=(?x1+?x2)] -> B[sem=?x2]",
          ["A[sem=(?x2+?x2)] -> B[sem=?x2]"],
      ),
      (
          "lhs_one_variable",
          "A[sem=(?x1+WALK)] -> B[sem=?x2]",
          ["A[sem=(?x2+WALK)] -> B[sem=?x2]"],
      ),
  )
  def test_replace_variable(self, production_string,
                            possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._replace_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_replace_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _replace_variable if a non-identity
    # edit could be made.
    self._assert_edit_fails(
        distractor_generation._replace_variable,
        "A[sem=?x1] -> B[sem=?x1]",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_no_variable",
          "A[sem=(WALK+JUMP)] -> B",
          ["A[sem=(WALK+WALK+JUMP)] -> B", "A[sem=(WALK+JUMP+JUMP)] -> B"],
      ),
      (
          "lhs_has_variable",
          "A[sem=(?x1+WALK)] -> B",
          ["A[sem=(?x1+WALK+WALK)] -> B"],
      ),
  )
  def test_repeat_non_variable(self, production_string,
                               possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._repeat_non_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_repeat_non_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _repeat_non_variable if its LHS contains at
    # least one non-variable.
    self._assert_edit_fails(
        distractor_generation._repeat_non_variable,
        "A[sem=?x1] -> B",
        _get_feature_grammar())

  @parameterized.named_parameters(
      (
          "lhs_two_variables",
          "A[sem=(?x1+?x2)] -> B",
          ["A[sem=(?x1+?x1+?x2)] -> B", "A[sem=(?x1+?x2+?x2)] -> B"],
      ),
      (
          "lhs_one_variable",
          "A[sem=(?x1+WALK)] -> B",
          ["A[sem=(?x1+?x1+WALK)] -> B"],
      ),
  )
  def test_repeat_variable(self, production_string,
                           possible_production_strings):
    self._assert_edit_yields_one_of(
        distractor_generation._repeat_variable,
        production_string,
        _get_feature_grammar(),
        possible_production_strings)

  def test_repeat_variable_should_raise_error_if_validation_fails(self):
    # The production is valid for _repeat_variable if its LHS contains at least
    # one variable.
    self._assert_edit_fails(
        distractor_generation._repeat_variable,
        "A[sem=WALK] -> 'walk'",
        _get_feature_grammar())


class EmptyStringTest(parameterized.TestCase):
  """Checks `_apply_edit` when a `sem` value is the empty string.

  The empty string appears either in the grammar or in the production being
  edited, depending on the parameterized case.
  """

  def setUp(self):
    super().setUp()
    # Fixed seed for the random state handed to `_apply_edit`.
    self.rng = np.random.RandomState(42)

  @parameterized.named_parameters(
      (
          "repeat_empty_string",
          "U[sem=UNUSED] -> 'run'",
          "U[sem=''] -> 'turn'",
          distractor_generation._repeat_non_variable,
          "U[sem=''] -> 'turn'",
      ),
      (
          "add_empty_string",
          "U[sem=''] -> 'run'",
          "U[sem=JUMP] -> 'turn'",
          distractor_generation._add_non_variable,
          "U[sem=JUMP] -> 'turn'",
      ),
      (
          "add_nonempty_string",
          "U[sem=JUMP] -> 'run'",
          "U[sem=''] -> 'turn'",
          distractor_generation._add_non_variable,
          "U[sem=JUMP] -> 'turn'",
      ),
  )
  def test_apply_edit_with_empty_string(self, grammar_string,
                                        original_production_string, edit,
                                        expected_production_string):
    parse = nltk_utils.production_from_production_string
    context_grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    source = parse(original_production_string)
    wanted = parse(expected_production_string)

    actual = distractor_generation._apply_edit(source, context_grammar,
                                               self.rng, edit)

    self.assertEqual(actual, wanted)


class UtilityFunctionsTest(parameterized.TestCase):
  """Tests for small helper functions in `distractor_generation`."""

  @parameterized.named_parameters(
      (
          "primitive_mapping",
          "U[sem='JUMP'] -> 'jump'",
          "'jump'",
      ),
      (
          "concat_rule",
          "D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]",
          "U[sem=?x1] W[sem=?x2]",
      ),
      (
          "function_rule",
          "C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]",
          "S[sem=?x1] 'and' S[sem=?x2]",
      ),
  )
  def test_rhs_string_from_production(self, production_string,
                                      expected_rhs_string):
    # The expected value is the text after "->" in the production string.
    parsed = nltk_utils.production_from_production_string(production_string)
    self.assertEqual(
        expected_rhs_string,
        distractor_generation._rhs_string_from_production(parsed))


def _get_provenance_by_production_for_grammar(grammar):
  """Maps each production of `grammar` to a provenance sourced from itself.

  Args:
    grammar: Object exposing `productions()`.

  Returns:
    A `production_composition.ProductionProvenanceDict` with one entry per
    production, each a `ProductionProvenance` whose `source` is that
    production.
  """
  provenance_map = production_composition.ProductionProvenanceDict()
  for rule in grammar.productions():
    self_provenance = production_composition.ProductionProvenance(source=rule)
    provenance_map[rule] = self_provenance
  return provenance_map


class AlternativeGrammarStrategyTest(parameterized.TestCase):
  """Tests distractor creation with the alternative-grammar strategy."""

  def setUp(self):
    super().setUp()
    # Fixed seed for the random state handed to the distractor generator.
    self.rng = np.random.RandomState(42)

  def test_create_distractor_production_with_alternative_grammar(self):
    grammar_generator = grammar_generation.GrammarGenerator()
    grammar = grammar_generator.generate_grammar()
    provenance_by_production = _get_provenance_by_production_for_grammar(
        grammar)
    # We try to generate a distractor for every production (except for the
    # pass-through rules) in the grammar to make sure that a distractor can be
    # generated for every rule type.
    for production in grammar.productions():
      if nltk_utils.is_pass_through_rule(production):
        continue
      distractor_creation_result = (
          distractor_generation
          .create_distractor_production_with_alternative_grammar(
              production,
              grammar,
              self.rng,
              grammar_generator,
              provenance_by_production,
              max_attempts_per_negative_example=20))

      distractor_production = distractor_creation_result.distractor

      # The same three subtests run on every loop iteration. Passing the
      # production as a subTest parameter makes a failure report identify the
      # production it belongs to, which the rhs / lhs-category assertion
      # messages alone do not show.
      production_label = str(production)

      with self.subTest(
          "distractor_should_be_different", production=production_label):
        self.assertNotEqual(production, distractor_production)

      with self.subTest(
          "distractor_should_have_the_same_rhs", production=production_label):
        self.assertEqual(production.rhs(), distractor_production.rhs())

      with self.subTest(
          "distractor_should_have_the_same_lhs_category",
          production=production_label):
        self.assertEqual(production.lhs()[nltk.grammar.TYPE],
                         distractor_production.lhs()[nltk.grammar.TYPE])


# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
