# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for parsing with FeatureGrammars in a conceptual learning setting.

The main entry point to this library is the RuleTrackingSemanticParser class.
RuleTrackingChart could theoretically be used directly as well, but would
typically be used only behind-the-scenes inside of RuleTrackingSemanticParser.
"""

import collections
import itertools
from typing import (AbstractSet, Any, Iterator, MutableSequence, MutableSet,
                    Sequence, Tuple, Type)

import nltk
from nltk import parse
from nltk.featstruct import FeatStruct
from nltk.featstruct import TYPE
from nltk.parse.chart import EdgeI
from nltk.parse.featurechart import FeatureTreeEdge

# Name of the FeatStruct feature that holds the semantics of a parse. The root
# label of every parse tree is expected to carry this feature.
_SEMANTICS: str = 'sem'


def _edge_to_rule_string(edge) -> str:
  """Returns a string representation consistent with that used in grammar rules.

  In particular, if the edge was created directly from a grammar rule parsed
  from an fcfg-style string format, _edge_to_rule_string(edge) should exactly
  equal the original grammar rule string from which the edge was created.

  This differs from the behavior of str(EdgeI) or repr(EdgeI), for example,
  because those functions output additional information such as a '*' symbol
  indicating the position of the edge's "dot".

  Args:
    edge: Edge to convert to string format.

  Returns:
    repr(edge.lhs()), followed by ' -> ', followed by the repr() of each
    symbol in edge.rhs() separated by single spaces.
  """
  # repr() renders nonterminals in bracketed feature form and terminals as
  # quoted strings, matching how they are written in fcfg rule strings.
  rhs = ' '.join(repr(symbol) for symbol in edge.rhs())
  return f'{edge.lhs()!r} -> {rhs}'


class RuleTrackingChart(parse.earleychart.FeatureIncrementalChart):
  """FeatureIncrementalChart that tracks the grammar rules used in parsing.

  This is intended primarily for use with RuleTrackingSemanticParser, although
  it can be used with an ordinary FeatureEarleyChartParser as well.

  Typical usage example:

    grammar = ...
    tokenized_sentence = ...
    parser = nltk.parse.FeatureEarleyChartParser(
        grammar, chart_class=parsing.RuleTrackingChart)
    chart = parser.chart_parse(tokenized_sentence)
    trees_with_rules = chart.parses_with_rules(grammar.start())
  """

  def initialize(self) -> None:
    super().initialize()

    # Mapping of each edge to its "previous" edges. This allows us to take an
    # arbitrary edge in a final parse tree, such as the following (based on the
    # example in http://www.nltk.org/howto/featgram.html):
    #    [Edge: [0:2] DP[AGR=[GND='f', NUM='pl', PERS=3], SEM=(GIRL)] ->
    #        D[AGR=[GND='f', NUM='pl', PERS=3], SEM=()] N[AGR=[GND='f',
    #        NUM='pl', PERS=3], SEM=(GIRL)] *],
    # and track backward to find the edge corresponding to the original grammar
    # rule that was responsible for the construction of the parse of that text
    # range, prior to unification, e.g.:
    #    [Edge: [0:0] DP[AGR=?a, SEM=(?s1+?s2)] -> * D[AGR=?a, SEM=?s1]
    #        N[AGR=?a, SEM=?s2] {}],
    #
    # In the case where the same edge could be reached via multiple parse
    # routes, there may be multiple "previous" edges -- one for each route.
    #
    # Note that the ancestor class nltk.Chart already tracks the relationship
    # between each edge and its "child" edges, which in the above example would
    # be the edges corresponding to sub-ranges such as [0:1], [1:2], etc.
    #
    # By recursively traversing the graph of previous and child edge pointers,
    # starting from the root edge of a parse tree, we can eventually reach all
    # of the edges corresponding to the original grammar rules that contributed
    # to a given parse (or to any of the possible parses that would lead to this
    # same edge).
    self._edge_to_previous_edges = collections.defaultdict(list)

  def _previous_pointers(self, edge: EdgeI) -> Sequence[EdgeI]:
    """Returns the previous edges of the given edge.

    Analogous to nltk.Chart.child_pointer_lists, but returns individual previous
    edges rather than lists of child edges. The lookup is read-only: an edge
    with no recorded previous edges yields an empty tuple, and no entry is
    added to the underlying mapping.

    Args:
      edge: The edge for which to retrieve the previous edges.
    """
    # .get() does not invoke the defaultdict factory, so querying edges such as
    # leaves or rule edges does not grow the mapping.
    return self._edge_to_previous_edges.get(edge, ())

  def _child_and_previous_edges(self, edge: EdgeI) -> Iterator[EdgeI]:
    """Returns an iterator over all child and previous edges.

    In the case where the given edge could be constructed via multiple possible
    parse routes, the resulting iterator will include any edge that served as a
    child or previous edge in any of those parses.

    Args:
      edge: The edge for which to retrieve the child and previous edges.
    """
    return itertools.chain(
        itertools.chain.from_iterable(self.child_pointer_lists(edge)),
        self._previous_pointers(edge))

  def insert_with_backpointer(self, new_edge: EdgeI, previous_edge: EdgeI,
                              child_edge: EdgeI) -> bool:
    """Adds a new edge to the chart, using a pointer to the previous edge.

    This is an override of a method defined by nltk.Chart that is called
    internally by the NLTK parsing framework during parsing, e.g., by
    FeatureFundamentalRule (parent class of FeatureScannerRule and
    FeatureCompleterRule that are used in the parsing strategy of
    FeatureEarleyChartParser). It is not meant to be called directly by any
    code written explicitly with RuleTrackingChart in mind.

    Args:
      new_edge: New edge being inserted.
      previous_edge: Previous edge from which new_edge was created.
      child_edge: Child edge from which new_edge was created.

    Returns:
      Whether the operation modified the chart.
    """
    # Recorded even if new_edge is already in the chart, so that every parse
    # route leading to new_edge is remembered.
    self._edge_to_previous_edges[new_edge].append(previous_edge)
    return super().insert_with_backpointer(new_edge, previous_edge, child_edge)

  def _is_rule_edge(self, edge: EdgeI) -> bool:
    """Returns whether the edge was introduced directly from a grammar rule.

    Args:
      edge: Edge to test.

    Returns:
      True iff the edge spans zero tokens and has neither child nor previous
      edges.
    """
    # In a bottom-up chart parsing algorithm, grammar rules are introduced
    # via self-loop edges (edges associated with spans of zero length) at
    # the beginning of the text range that the rule will eventually cover.
    # For details, see: http://www.nltk.org/book_1ed/ch08-extras.html
    return (edge.start() == edge.end() and
            not any(self._child_and_previous_edges(edge)))

  def _gather_rule_edges(self, edge: EdgeI,
                         rule_edges: MutableSet[EdgeI]) -> None:
    """Adds to rule_edges every rule edge reachable from the given edge.

    Walks the graph of child and previous edge pointers starting at edge
    (inclusive). The walk is iterative and remembers which edges it has
    already visited, so an edge shared by several parse routes is expanded
    only once, and the walk terminates even if the pointer graph contains a
    cycle.

    Args:
      edge: Edge from which to start the traversal.
      rule_edges: Set to which the discovered rule edges are added.
    """
    visited = {edge}
    pending = [edge]
    while pending:
      current = pending.pop()
      if self._is_rule_edge(current):
        rule_edges.add(current)
      for neighbor in self._child_and_previous_edges(current):
        if neighbor not in visited:
          visited.add(neighbor)
          pending.append(neighbor)

  def _rule_edges(self, edge: EdgeI) -> AbstractSet[EdgeI]:
    """Returns the edges representing the grammar rules used in a parse.

    Args:
      edge: Edge representing the root of a parse tree.
    """
    rule_edges = set()
    self._gather_rule_edges(edge, rule_edges)
    return rule_edges

  def _rules(self, edge: EdgeI) -> AbstractSet[str]:
    """Returns string representations of the grammar rules used in a parse.

    Args:
      edge: Edge representing the root of a parse tree.

    Returns:
      Set of grammar rules in the format supported by FeatureGrammar.fromstring.
    """

    return frozenset(
        _edge_to_rule_string(rule_edge) for rule_edge in self._rule_edges(edge))

  def parses_with_rules(
      self,
      start: FeatStruct,
      tree_class: Type[nltk.Tree] = nltk.Tree
  ) -> Iterator[Tuple[nltk.Tree, AbstractSet[str]]]:
    """Yields all valid parse trees together with the rules used for each.

    Variation of FeatureChart.parses that returns Tuples of a parse tree
    and a set of grammar rules used, rather than the parse trees alone.

    The rule set is computed per root edge. When several trees share the same
    root edge, each is paired with the same set, which covers the rules of
    every parse route that leads to that edge.

    Args:
      start: Root node type of parse tree. See FeatureChart.parses.
      tree_class: Class of the tree to return. See FeatureChart.parses.

    Yields:
      (tree, rules) pairs, where rules is a frozenset of rule strings.
    """
    for edge in self.select(start=0, end=self._num_leaves):
      if ((isinstance(edge, FeatureTreeEdge)) and
          (edge.lhs()[TYPE] == start[TYPE]) and  # pytype: disable=unsupported-operands
          (nltk.featstruct.unify(edge.lhs(), start, rename_vars=True))):
        # Computed lazily (only if the edge has at least one complete tree)
        # and once per edge, since the result depends only on the edge.
        rules = None
        for tree in self.trees(edge, complete=True, tree_class=tree_class):
          if rules is None:
            rules = self._rules(edge)
          yield tree, rules


class RuleTrackingSemanticParser(nltk.parse.FeatureEarleyChartParser):
  """FeatureGrammar-based semantic parser that tracks rules used in parsing.

  Parsing is delegated to FeatureEarleyChartParser configured with a
  RuleTrackingChart, so that every parse can be paired with the grammar rules
  that contributed to it.
  """

  def __init__(self, grammar, **parser_args):
    # chart_class is always RuleTrackingChart; supplying chart_class again in
    # parser_args is a duplicate keyword argument and raises TypeError.
    super().__init__(grammar, chart_class=RuleTrackingChart, **parser_args)

  def semantic_parse(
      self, tokenized_sentence: Sequence[str]
  ) -> Iterator[Tuple[str, AbstractSet[str]]]:
    """Determines the semantics corresponding to the given input sentence.

    The grammar must give each parse tree's root label a feature named 'sem'
    whose value is either a string or an iterable of strings representing the
    semantics of the input sentence.

    Examples of grammar rules that meet these requirements:
      D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]
      U[sem='WALK'] -> 'walk'

    Args:
      tokenized_sentence: Terminal tokens representing the input for which we
        are to determine the corresponding semantics.

    Yields:
      One (semantics, rules) pair per complete parse. semantics is the 'sem'
      value itself when that value is a string; otherwise its truthy elements
      are joined with single spaces, so empty elements are dropped. rules is the
      rule set reported by RuleTrackingChart.parses_with_rules.

    Raises:
      ValueError: If a parse tree's root label has no 'sem' feature.
    """
    # Based on example usages of FeatureEarleyChartParser in:
    #   http://www.nltk.org/howto/featgram.html
    # and example usages of FeaturesGrammars for semantic parsing in:
    #   https://www.nltk.org/book/ch10.html#ex-sem3
    chart = self.chart_parse(tokenized_sentence)
    start_symbol = self.grammar().start()
    for tree, used_rules in chart.parses_with_rules(start_symbol):
      label = tree.label()
      if _SEMANTICS not in label:
        raise ValueError(f'FeatStruct lacking {_SEMANTICS} tag: {label}')
      sem_value = label[_SEMANTICS]
      if isinstance(sem_value, str):
        yield sem_value, used_rules
      else:
        # filter(None, ...) keeps only truthy (non-empty) elements.
        yield ' '.join(filter(None, sem_value)), used_rules
