# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable

from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import parsing


def _create_grammar(
    start_symbol: str,
    grammar_rules: Iterable[str]) -> 'nltk.grammar.FeatureGrammar':
  """Returns a FeatureGrammar with the given contents for test purposes.

  Args:
    start_symbol: Name of the start nonterminal; it is written into the
      leading '% start' directive of the grammar string.
    grammar_rules: Grammar productions, one per string. They are placed on
      separate lines below the start directive, in the order given.

  Returns:
    The grammar produced by nltk.grammar.FeatureGrammar.fromstring for the
    assembled grammar string.
  """
  grammar_lines = [f'% start {start_symbol}', *grammar_rules]
  return nltk.grammar.FeatureGrammar.fromstring('\n'.join(grammar_lines))


class RuleTrackingChartTest(absltest.TestCase):
  """Tests parsing.RuleTrackingChart used as an Earley parser chart class."""

  def test_parses_with_rules_returns_pair_of_parse_tree_and_rules_used(self):
    """Checks parses_with_rules against a plain chart parse of the input."""
    used_rules = [
        'D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]',
        "U[sem='WALK'] -> 'walk'",
        "W[sem='LTURN'] -> 'left'",
    ]
    # The input token of this production does not occur in the sentence below.
    unused_rules = [
        "W[sem='SOME_IRRELEVANT_OUTPUT'] -> 'some_irrelevant_input'",
    ]
    grammar = _create_grammar('D', used_rules + unused_rules)
    tokens = ['walk', 'left']

    # Baseline for comparison: parse with the parser's default chart class.
    baseline_chart = nltk.parse.FeatureEarleyChartParser(grammar).chart_parse(
        tokens)
    baseline_trees = list(baseline_chart.parses(grammar.start()))
    baseline_tree = baseline_trees[0]

    # Same grammar and sentence, with RuleTrackingChart as the chart class.
    tracking_chart = nltk.parse.FeatureEarleyChartParser(
        grammar, chart_class=parsing.RuleTrackingChart).chart_parse(tokens)
    tracked_parses = list(tracking_chart.parses_with_rules(grammar.start()))
    tracked_tree, tracked_rules = tracked_parses[0]

    with self.subTest('Returns same number of parses as chart_parse'):
      self.assertLen(tracked_parses, len(baseline_trees))
    with self.subTest('In this example, there should only be one parse'):
      self.assertLen(tracked_parses, 1)
    with self.subTest('Returns same parse tree as chart_parse'):
      self.assertEqual(baseline_tree, tracked_tree)
    with self.subTest('Returns the set of rules used'):
      self.assertCountEqual(used_rules, tracked_rules)


class SemanticParseTest(parameterized.TestCase):
  """Tests parsing.RuleTrackingSemanticParser.semantic_parse."""

  def test_returns_semantics_plus_only_the_rules_relevant_to_the_parse(self):
    """Rules that take no part in the parse are absent from the result."""
    used_rules = [
        'D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]',
        "U[sem='WALK'] -> 'walk'",
        "W[sem='LTURN'] -> 'left'",
    ]
    unused_rules = [
        'D[sem=?x1] -> U[sem=?x1]',
        "U[sem=('LOOK')] -> 'look'",
    ]
    grammar = _create_grammar('D', used_rules + unused_rules)
    parser = parsing.RuleTrackingSemanticParser(grammar)

    expected = [('LTURN WALK', frozenset(used_rules))]
    actual = parser.semantic_parse(['walk', 'left'])
    self.assertCountEqual(expected, actual)

  def test_multiple_parses_with_different_semantics(self):
    """Each parse should be reported as its own (semantics, rules) tuple."""
    d_swapped = 'D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]'
    d_in_order = 'D[sem=(?x1+?x2)] -> U[sem=?x1] W[sem=?x2]'
    u_walk = "U[sem='WALK'] -> 'walk'"
    w_left = "W[sem='LTURN'] -> 'left'"
    grammar = _create_grammar('D', [d_swapped, d_in_order, u_walk, w_left])
    parser = parsing.RuleTrackingSemanticParser(grammar)

    expected = [
        ('LTURN WALK', frozenset([d_swapped, u_walk, w_left])),
        ('WALK LTURN', frozenset([d_in_order, u_walk, w_left])),
    ]
    actual = parser.semantic_parse(['walk', 'left'])
    self.assertCountEqual(expected, actual)

  def test_multiple_parses_with_same_semantics_but_different_final_edge(self):
    """Each parse should be reported as its own (semantics, rules) tuple."""
    v_d_then_u = 'V[sem=(?x1+?x2)] -> D[sem=?x1] U[sem=?x2]'
    v_u_then_d = 'V[sem=(?x1+?x2)] -> U[sem=?x1] D[sem=?x2]'
    # Productions expected in the rule set of both parses.
    shared_rules = [
        'D[sem=(?x1+?x2)] -> U[sem=?x1] U[sem=?x2]',
        "U[sem='WALK'] -> 'walk'",
        "U[sem='LOOK'] -> 'look'",
    ]
    grammar = _create_grammar('V', [v_d_then_u, v_u_then_d] + shared_rules)
    parser = parsing.RuleTrackingSemanticParser(grammar)

    expected = [
        ('WALK LOOK WALK', frozenset([v_d_then_u] + shared_rules)),
        ('WALK LOOK WALK', frozenset([v_u_then_d] + shared_rules)),
    ]
    actual = parser.semantic_parse(['walk', 'look', 'walk'])
    self.assertCountEqual(expected, actual)

  def test_multiple_parses_with_same_semantics_and_same_final_edge(self):
    """Each parse should be reported as its own (semantics, rules) tuple.

    Currently, though, the rule set attached to each tuple is the union of
    the rules relevant to the two parses, which is what this test expects.
    """
    all_rules = [
        'S[sem=?x1] -> V[sem=?x1]',
        'V[sem=(?x1+?x2)] -> D[sem=?x1] U[sem=?x2]',
        'V[sem=(?x1+?x2)] -> U[sem=?x1] D[sem=?x2]',
        'D[sem=(?x1+?x2)] -> U[sem=?x1] U[sem=?x2]',
        "U[sem='WALK'] -> 'walk'",
        "U[sem='LOOK'] -> 'look'",
    ]
    grammar = _create_grammar('S', all_rules)
    parser = parsing.RuleTrackingSemanticParser(grammar)

    # Two parses, both carrying the full rule set.
    expected = [('WALK LOOK WALK', frozenset(all_rules))] * 2
    actual = parser.semantic_parse(['walk', 'look', 'walk'])
    self.assertCountEqual(expected, actual)

  def test_raises_error_if_featurestruct_lacks_sem_tag(self):
    """A grammar whose features have no 'sem' entry triggers a ValueError."""
    grammar = _create_grammar('U', ["U[wrongtag='WALK'] -> 'walk'"])
    parser = parsing.RuleTrackingSemanticParser(grammar)
    with self.assertRaisesRegex(ValueError, 'FeatStruct lacking sem tag'):
      # Wrap in list() so the result is fully consumed inside the context.
      list(parser.semantic_parse(['walk']))

  def test_does_not_insert_unnecessary_spaces_in_single_token_semantics(self):
    """Single-token semantics come back without surrounding whitespace."""
    single_rule = ["U[sem='LOOK'] -> 'look'"]
    grammar = _create_grammar('U', single_rule)
    parser = parsing.RuleTrackingSemanticParser(grammar)

    expected = [('LOOK', frozenset(single_rule))]
    actual = parser.semantic_parse(['look'])
    self.assertCountEqual(expected, actual)


# Hand control to absltest's entry point when this file is run as a script.
if __name__ == '__main__':
  absltest.main()
