# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import inference
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition


class ProductionCompositionUtilitiesTest(parameterized.TestCase):
  """Tests for the module-level helpers of `production_composition`."""

  # Production whose variable names are deliberately irregular; shared by the
  # variable-renaming tests below.
  _IRREGULAR_VARIABLES_STRING = (
      "A[sem=(?u1+?u+?v3)] -> B[sem=?u] 'and' C[sem=?v3] "
      "'after' D[sem=?u1]")

  @staticmethod
  def _parse(production_string):
    """Returns the production described by `production_string`."""
    return nltk_utils.production_from_production_string(production_string)

  def test_get_variables(self):
    """`_get_variables` lists the variables occurring in the production."""
    production = self._parse(
        "C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]")
    self.assertEqual(
        production_composition._get_variables(production),
        [nltk.Variable("?x1"), nltk.Variable("?x2")])

  @parameterized.named_parameters(
      dict(testcase_name="no_variable",
           production_string="A[sem=WALK] -> 'walk'",
           expected=0),
      dict(testcase_name="one_variable",
           production_string="A[sem=WALK] -> B[sem=?x1] 'walk'",
           expected=1),
      dict(testcase_name="repeated_variables",
           production_string=(
               "A[sem=WALK] -> B[sem=?x1] C[sem=?x1] D[sem=?x2] 'walk'"),
           expected=2),
  )
  def test_num_variables_from_production(self, production_string, expected):
    """Distinct variables are counted; a repeated variable counts once."""
    production = self._parse(production_string)
    self.assertEqual(
        production_composition.num_variables_from_production(production),
        expected)

  def test_normalize_variable_names(self):
    """Variables are renamed to `?<token>1`, `?<token>2`, ... in RHS order."""
    production = self._parse(self._IRREGULAR_VARIABLES_STRING)
    renamed = production_composition._normalize_variable_names(
        production, token="z")
    self.assertEqual(
        str(renamed),
        "A[sem=(?z3+?z1+?z2)] -> B[sem=?z1] 'and' "
        "C[sem=?z2] 'after' D[sem=?z3]")

  def test_normalize_variable_names_should_raise_error_if_token_exists(self):
    """A token that already prefixes an existing variable is rejected."""
    # `?v3` already exists, so renaming with token "v" could collide.
    production = self._parse(self._IRREGULAR_VARIABLES_STRING)
    with self.assertRaisesRegex(
        ValueError, r"Variable name \w+ already exists in production"):
      production_composition._normalize_variable_names(production, token="v")

  @parameterized.named_parameters(
      dict(testcase_name="string_and_variable",
           production_string="A[sem=(?x1+(WALK, WALK))] -> B[sem=?x1]",
           expected_production_string="A[sem=(?x1+WALK+WALK)] -> B[sem=?x1]"),
      dict(testcase_name="string_feature_tuple",
           production_string="A[sem=(WALK, WALK)] -> B[sem=?x1]",
           expected_production_string="A[sem=(WALK, WALK)] -> B[sem=?x1]"),
      dict(testcase_name="string_feature_concat",
           production_string="A[sem=(WALK+WALK)] -> B[sem=?x1]",
           expected_production_string="A[sem=(WALK, WALK)] -> B[sem=?x1]"),
  )
  def test_normalize_semantics(self, production_string,
                               expected_production_string):
    """Nested tuples are flattened and pure string concats become tuples."""
    normalized = production_composition.normalize_semantics(
        self._parse(production_string))
    self.assertEqual(str(normalized), expected_production_string)

  @parameterized.named_parameters(
      dict(testcase_name="sem_feature_wrong_type",
           production_string="A[sem={?x1,(WALK, WALK)}] -> B[sem=?x1]",
           expected_regex=(
               "Semantic feature of LHS must be a Variable, "
               "FeatureValueTuple, or FeatureValueConcat")),
      dict(testcase_name="sem_item_wrong_type",
           production_string="A[sem=(?x1+{WALK, WALK})] -> B[sem=?x1]",
           expected_regex=(
               "Item in LHS semantic feature should be a Variable or a "
               "FeatureValueTuple")),
  )
  def test_normalize_semantics_should_raise_value_error(self, production_string,
                                                        expected_regex):
    """Unsupported LHS semantic shapes raise a descriptive ValueError."""
    production = self._parse(production_string)
    with self.assertRaisesRegex(ValueError, expected_regex):
      production_composition.normalize_semantics(production)

  @parameterized.named_parameters(
      dict(testcase_name="composable",
           parent_string="C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]",
           other_parent_string="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
           expected=[0, 2]),
      dict(testcase_name="not_composable",
           parent_string="C[sem=(?x2+?x1)] -> X[sem=?x1] 'after' X[sem=?x2]",
           other_parent_string="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
           expected=[]),
  )
  def test_composable_indices(self, parent_string, other_parent_string,
                              expected):
    """Only RHS positions matching the other parent's LHS are reported."""
    parent = self._parse(parent_string)
    other_parent = self._parse(other_parent_string)
    self.assertEqual(
        production_composition.composable_indices(parent, other_parent),
        expected)


class ProductionCompositionTest(parameterized.TestCase):
  """Tests for `production_composition.compose`."""

  # Parents shared by the basic composition tests. They can be composed at RHS
  # indices 0 and 2 (the `S` nonterminals) but not at 1 (the 'after' terminal).
  _AFTER_STRING = "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]"
  _TWICE_STRING = "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"
  # Production whose semantics is the empty string.
  _NOTHING_STRING = "B[sem=''] -> 'nothing'"

  @staticmethod
  def _parse(production_string):
    """Returns the production described by `production_string`."""
    return nltk_utils.production_from_production_string(production_string)

  @parameterized.named_parameters(
      dict(testcase_name="composable_at_0",
           parent_string=_AFTER_STRING,
           other_parent_string=_TWICE_STRING,
           index=0,
           expected_production_string=(
               "C[sem=(?x2+?x1+?x1)] -> V[sem=?x1] 'twice' "
               "'after' S[sem=?x2]")),
      dict(testcase_name="composable_at_2",
           parent_string=_AFTER_STRING,
           other_parent_string=_TWICE_STRING,
           index=2,
           expected_production_string=(
               "C[sem=(?x2+?x2+?x1)] -> S[sem=?x1] 'after' "
               "V[sem=?x2] 'twice'")),
  )
  def test_compose(self, parent_string, other_parent_string, index,
                   expected_production_string):
    """The other parent's RHS replaces the parent's nonterminal at `index`."""
    parent = self._parse(parent_string)
    other_parent = self._parse(other_parent_string)
    result = production_composition.compose(parent, other_parent, index)
    self.assertEqual(str(result), expected_production_string)

  def test_compose_should_raise_value_error_if_not_composable_at_index(self):
    """Composing at a terminal position is rejected."""
    parent = self._parse(self._AFTER_STRING)
    other_parent = self._parse(self._TWICE_STRING)
    with self.assertRaisesRegex(ValueError,
                                "Productions not composable at index"):
      production_composition.compose(parent, other_parent, 1)

  @parameterized.named_parameters(
      dict(testcase_name="only_empty_string",
           parent_production_string="A[sem=?x1] -> B[sem=?x1]",
           other_parent_production_string=_NOTHING_STRING,
           index=0,
           expected_composed_production_string="A[sem=''] -> 'nothing'"),
      dict(testcase_name="empty_string_and_NONEMPTY_string",
           parent_production_string="A[sem=(?x1+NONEMPTY)] -> B[sem=?x1]",
           other_parent_production_string=_NOTHING_STRING,
           index=0,
           expected_composed_production_string=(
               "A[sem='NONEMPTY'] -> 'nothing'")),
      dict(testcase_name="empty_string_and_variable",
           parent_production_string=(
               "A[sem=(?x1+?x2)] -> B[sem=?x1] C[sem=?x2]"),
           other_parent_production_string=_NOTHING_STRING,
           index=0,
           expected_composed_production_string=(
               "A[sem=?x1] -> 'nothing' C[sem=?x1]")),
      dict(testcase_name="empty_string_and_NONEMPTY_string_and_variable",
           parent_production_string=(
               "A[sem=(?x1+NONEMPTY+?x2)] -> B[sem=?x1] C[sem=?x2]"),
           other_parent_production_string=_NOTHING_STRING,
           index=0,
           expected_composed_production_string=(
               "A[sem=(NONEMPTY+?x1)] -> 'nothing' C[sem=?x1]")),
  )
  def test_compose_should_handle_empty_strings_correctly(
      self, parent_production_string, other_parent_production_string, index,
      expected_composed_production_string):
    """Empty-string semantics vanish from the composed LHS semantics."""
    composed_production = production_composition.compose(
        self._parse(parent_production_string),
        self._parse(other_parent_production_string),
        index)
    self.assertEqual(
        str(composed_production), expected_composed_production_string)


class ProductionProvenanceDictTest(absltest.TestCase):
  """Tests for `production_composition.ProductionProvenanceDict`."""

  def test_should_raise_error_if_production_and_provenance_mismatch(self):
    """Storing a provenance under a production it does not yield fails."""
    parse = nltk_utils.production_from_production_string
    provenance = production_composition.ProductionProvenance(
        source=parse(
            "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]"),
        compositions=((parse("U[sem='WALK'] -> 'walk'"), 0),))

    # Same shape as the composition above, but the semantics end in RUN even
    # though the composed `U` production carries WALK.
    wrong_key = parse("V[sem=(?x1+?x1+RUN)] -> 'walk' 'opposite' W[sem=?x1]")

    provenance_by_production = (
        production_composition.ProductionProvenanceDict())

    with self.assertRaisesRegex(ValueError,
                                "Provenance and production mismatch."):
      provenance_by_production[wrong_key] = provenance


class ProductionProvenanceTest(parameterized.TestCase):
  """Tests for `production_composition.ProductionProvenance`."""

  # Production strings shared by several tests below.
  _OPPOSITE_STRING = (
      "V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]")
  _QUOTED_WALK_STRING = "U[sem='WALK'] -> 'walk'"
  _AND_STRING = "C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]"
  _WALK_STRING = "S[sem=WALK] -> 'walk'"

  @staticmethod
  def _parse(production_string):
    """Returns the production described by `production_string`."""
    return nltk_utils.production_from_production_string(production_string)

  def _assert_replaced_and_recorded(self, new_provenance, expected_provenance,
                                    production_with_new_provenance,
                                    provenance_by_production):
    """Checks that `replace` built `expected_provenance` and recorded it."""
    with self.subTest(
        "should_construct_new_provenance_with_production_replaced"):
      self.assertEqual(new_provenance, expected_provenance)

    with self.subTest("should_record_new_provenance"):
      self.assertIn(production_with_new_provenance, provenance_by_production)
      self.assertEqual(provenance_by_production[production_with_new_provenance],
                       new_provenance)

  def test_instantiation_should_calculate_production(self):
    """`get_production` applies the recorded compositions to the source."""
    source = self._parse(self._OPPOSITE_STRING)
    walk = self._parse(self._QUOTED_WALK_STRING)
    expected = self._parse(
        "V[sem=(?x1+?x1+WALK)] -> 'walk' 'opposite' W[sem=?x1]")

    provenance = production_composition.ProductionProvenance(
        source=source, compositions=((walk, 0),))

    self.assertEqual(provenance.get_production(), expected)

  def test_accessing_production_should_raise_error_if_not_composable(self):
    """Invalid compositions surface only when the production is built."""
    source = self._parse(self._OPPOSITE_STRING)
    walk = self._parse(self._QUOTED_WALK_STRING)

    # RHS index 1 is the 'opposite' terminal, so nothing can be composed there;
    # constructing the provenance is nevertheless allowed.
    invalid_provenance = production_composition.ProductionProvenance(
        source=source, compositions=((walk, 1),))

    with self.assertRaisesRegex(ValueError,
                                "Productions not composable at index"):
      _ = invalid_provenance.get_production()

  def test_calling_compose_should_record_provenance(self):
    """`compose` records provenances and shifts indices when splicing."""
    # The productions are chosen so that splicing production_234's provenance
    # into the final composition must shift its composition indices.
    (production_0, production_1, production_2, production_3,
     production_4) = [
         self._parse(production_string) for production_string in (
             self._AND_STRING,
             self._WALK_STRING,
             "S[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]",
             "U[sem=JUMP] -> 'jump'",
             "W[sem=RTURN] -> 'right'",
         )
     ]
    compose = production_composition.compose
    provenance_by_production = production_composition.ProductionProvenanceDict()

    # C[sem=(WALK+?x2)] -> 'walk' 'and' S[sem=?x2]
    production_01 = compose(production_0, production_1, 0,
                            provenance_by_production)
    # S[sem=(?x2+JUMP)] -> 'jump' W[sem=?x2]
    production_23 = compose(production_2, production_3, 0,
                            provenance_by_production)
    # S[sem=(RTURN+JUMP)] -> 'jump' 'right'
    production_234 = compose(production_23, production_4, 1,
                             provenance_by_production)
    # C[sem=(WALK+RTURN+JUMP)] -> 'walk' 'and' 'jump' 'right'
    composed_production = compose(production_01, production_234, 2,
                                  provenance_by_production)

    source_productions = provenance_by_production.get_source_productions()
    with self.subTest("should_record_correct_source_productions"):
      self.assertCountEqual(
          source_productions,
          [production_0, production_1, production_2, production_3,
           production_4])

    with self.subTest(
        "should_record_parents_productions_as_their_own_provenances"):
      for parent in (production_0, production_1, production_2):
        self.assertEqual(
            provenance_by_production[parent],
            production_composition.ProductionProvenance(source=parent))

    with self.subTest("should_record_composed_production_provenance"):
      self.assertEqual(
          provenance_by_production[production_01],
          production_composition.ProductionProvenance(
              source=production_0, compositions=((production_1, 0),)))
      self.assertEqual(
          provenance_by_production[production_234],
          production_composition.ProductionProvenance(
              source=production_2,
              compositions=((production_3, 0), (production_4, 1))))

    with self.subTest("should_correctly_shift_indices_in_spliced_provenance"):
      # production_234 is spliced in at index 2, so its own composition
      # indices (0 and 1) are offset by 2.
      offset = 2
      self.assertEqual(
          provenance_by_production[composed_production],
          production_composition.ProductionProvenance(
              source=production_0,
              compositions=((production_1, 0), (production_2, offset),
                            (production_3, 0 + offset),
                            (production_4, 1 + offset))))

  def test_replace_should_correctly_replace_source(self):
    """Replacing the source keeps the compositions and records the result."""
    old_source = self._parse(self._AND_STRING)
    walk = self._parse(self._WALK_STRING)
    new_source = self._parse(
        "C[sem=(?x2+AFTER+?x1)] -> S[sem=?x1] 'and' S[sem=?x2]")
    provenance_by_production = production_composition.ProductionProvenanceDict()

    provenance = production_composition.ProductionProvenance(
        source=old_source, compositions=((walk, 0), (walk, 2)))
    new_provenance = provenance.replace(old_source, new_source,
                                        provenance_by_production)
    production_with_new_provenance = new_provenance.get_production()

    self._assert_replaced_and_recorded(
        new_provenance,
        production_composition.ProductionProvenance(
            source=new_source, compositions=((walk, 0), (walk, 2))),
        production_with_new_provenance,
        provenance_by_production)

  def test_replace_should_correctly_replace_production_in_compositions(self):
    """Every occurrence of the old production in compositions is replaced."""
    source = self._parse(self._AND_STRING)
    old_walk = self._parse(self._WALK_STRING)
    # Same RHS as `old_walk`, different semantics.
    new_walk = self._parse("S[sem=JUMP] -> 'walk'")
    provenance_by_production = production_composition.ProductionProvenanceDict()

    provenance = production_composition.ProductionProvenance(
        source=source, compositions=((old_walk, 0), (old_walk, 2)))
    new_provenance = provenance.replace(old_walk, new_walk,
                                        provenance_by_production)
    production_with_new_provenance = new_provenance.get_production()

    self._assert_replaced_and_recorded(
        new_provenance,
        production_composition.ProductionProvenance(
            source=source, compositions=((new_walk, 0), (new_walk, 2))),
        production_with_new_provenance,
        provenance_by_production)

  def test_replace_should_raise_error_if_different_rhs(self):
    """`replace` refuses a replacement whose RHS differs from the original."""
    source = self._parse(self._AND_STRING)
    walk = self._parse(self._WALK_STRING)
    provenance = production_composition.ProductionProvenance(
        source=source, compositions=((walk, 0),))

    # Same semantics as `walk`, but a different terminal on the RHS.
    jump_with_walk_sem = self._parse("S[sem=WALK] -> 'jump'")
    provenance_by_production = production_composition.ProductionProvenanceDict()

    with self.assertRaisesRegex(
        ValueError,
        "Old production and new production do not have the same RHS."):
      provenance.replace(walk, jump_with_walk_sem, provenance_by_production)

  def test_get_production_and_json_serialization_roundtrips(self):
    """Provenances rebuild their productions and survive a JSON roundtrip."""
    # Here we use a minimal grammar to save time but still capture the shape of
    # the SCAN grammar.
    grammar_string = """% start C
    C[sem=(?x1+?x2)] -> S[sem=?x1] 'and' S[sem=?x2]
    C[sem=?x1] -> S[sem=?x1]
    S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'
    S[sem=?x1] -> V[sem=?x1]
    V[sem=(?x2+?x2+?x1)] -> U[sem=?x1] 'opposite' W[sem=?x2]
    V[sem=?x1] -> D[sem=?x1]
    D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
    D[sem=?x1] -> U[sem=?x1]
    U[sem='WALK'] -> 'walk'
    W[sem='RTURN'] -> 'right'"""
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)

    provenance_by_production = production_composition.ProductionProvenanceDict()
    inference_engine = inference.InferenceEngine(
        provenance_by_production=provenance_by_production)

    # Adding every source production lets the inference engine enumerate all
    # possible compositions (693 productions per the original author's note;
    # that count is not asserted here).
    for source_production in grammar.productions():
      inference_engine.add_production(source_production)

    with self.subTest("should_record_correct_source_productions"):
      self.assertCountEqual(provenance_by_production.get_source_productions(),
                            grammar.productions())

    with self.subTest("get_production_should_recover_production"):
      for production in inference_engine.all_productions:
        rebuilt = provenance_by_production[production].get_production()
        self.assertEqual(production, rebuilt)

    with self.subTest("json_serialization_roundtrip_should_recover_provenance"):
      from_json = production_composition.ProductionProvenance.from_json
      for production in inference_engine.all_productions:
        original_provenance = provenance_by_production[production]
        self.assertEqual(original_provenance,
                         from_json(original_provenance.to_json()))


if __name__ == "__main__":
  # Run the test cases defined in this module via absl's test entry point.
  absltest.main()
