# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for manipulating production trees."""

import dataclasses
import functools
import itertools
import re
from typing import Dict, List, Optional, Sequence, Set, Tuple

import networkx as nx
import nltk

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import rule_conversion

# The string used as a placeholder for 'self' in an outer substitution string.
OUTER_SUBSTITUTION_PLACEHOLDER = '__'

# Map of variable name (e.g. 'x1') to set of substitutions.
VariableSubstitutionsMap = Dict[str, Set[str]]

# Per-production bookkeeping: each source production maps to the variable
# substitutions map of that production, or to the set of outer substitution
# strings collected for that production, respectively.
VariableSubstitutionsByProduction = Dict[nltk.grammar.Production,
                                         VariableSubstitutionsMap]
OuterSubstitutionsByProduction = Dict[nltk.grammar.Production, Set[str]]

# Same structures as the two above, but keyed by rule string (in some rule
# format) rather than by raw production.
VariableSubstitutionsByRule = Dict[str, VariableSubstitutionsMap]
OuterSubstitutionsByRule = Dict[str, Set[str]]

# A value representing "infinite" for use when counting variable substitutions
# and outer substitutions, e.g., when a variable appears in unsubstituted form
# (which is equivalent to illustrating all possible substitutions) or when a
# rule appears at the top of a rule application tree (which counts as evidence
# that the rule holds in all possible outer contexts). The choice of specific
# value to use here is arbitrary, as long as it is much bigger than any value
# that would occur naturally. We prefer this value over sys.maxsize
# (int: 9223372036854775807) for readability, and over math.inf (float: inf) for
# type compatibility.
INFINITE_SUBSTITUTIONS = 1000000


def _get_token_index_by_variable_index(
    production):
  """Maps each variable of a production to its position in the rhs.

  Variables are the rhs elements that are `nltk.grammar.Nonterminal` instances;
  every other rhs element is skipped.

  Args:
    production: The production whose rhs is scanned.

  Returns:
    List whose ith entry is the rhs token index of the production's ith
    variable, in left-to-right order.
  """
  return [
      position for position, symbol in enumerate(production.rhs())
      if isinstance(symbol, nltk.grammar.Nonterminal)
  ]


def _get_composition_index_mapping_for_single_production(
    production_tree
):
  """Builds the composition index mapping for one production tree node.

  A composition index mapping translates composition indices of the form used
  in ProductionProvenance (i.e., token indices in the rhs of a partially
  composed production) into the ProductionTree node and child index that the
  token corresponds to.

  Only the top production of `production_tree` is considered here, so the
  resulting mapping is the simple one for a single production.

  Args:
    production_tree: The production tree node whose own production is to be
      mapped. Its child nodes, if any, are not consulted.

  Returns:
    List with one entry per rhs token of the production. The entry for a
    variable token is a (ProductionTree, child index) pair, where the tree is
    always `production_tree` itself and the child index counts variables from
    left to right. The entry for a non-variable token is None.
  """
  # Hands out 0, 1, 2, ... to the variables in the order they are encountered.
  variable_counter = itertools.count()
  return [
      (production_tree, next(variable_counter))
      if isinstance(symbol, nltk.grammar.Nonterminal) else None
      for symbol in production_tree.production.rhs()
  ]


@dataclasses.dataclass
class ProductionTree:
  """Production provenance represented as a tree of source productions.

  By source productions, we mean here the productions of the grammar that the
  context was based on, which can include the productions of hidden rules,
  unreliable rules, and distractor rules, in addition to explicitly asserted
  rules. The production trees of the various context examples can thus be used
  to determine the number of unique variable substitutions and outer
  substitutions that were illustrated in the context for a given hidden,
  unreliable, or distractor rule.

  Attributes:
    production: The source production applied at this node.
    children: The production trees of the expressions that are substituted for
      each of the variables in `production`, or None if the variable was not
      replaced. E.g., if a production contains variables `?x1` and `?x2`, then
      `children[0]` represents the expression that is substituted for `?x1`,
      while `children[1]` corresponds to `?x2`. If only `?x2` was replaced, then
      `children[0]` should equal None, so that `children[1]` still represents
      the expression substituted for `?x2`.
    graph: A directed graph representation of the tree. Each node corresponds to
      a production, and each edge corresponds to a composition. The direction of
      the edge goes from parent to child. The graph is built on first access
      and then cached, so it does not reflect later changes to `children`.
  """

  production: nltk.grammar.Production
  children: List[Optional['ProductionTree']] = dataclasses.field(
      default_factory=list)

  @classmethod
  def from_production(cls,
                      production):
    """Returns a ProductionTree containing just a single production.

    Args:
      production: The production to store in the tree.

    Returns:
      ProductionTree with the given production and with children pre-populated
      with a None placeholder for each of the production's variables.
    """
    num_variables = production_composition.num_variables_from_production(
        production)
    return cls(production=production, children=[None] * num_variables)

  @classmethod
  def from_production_provenance(
      cls, production_provenance
  ):
    """Returns a ProductionTree with the same content as the provenance.

    Args:
      production_provenance: The provenance to convert. Its `source` is used as
        the root production, and each (production, index) pair in its
        `compositions` is inserted into the tree in order, where `index` is the
        rhs token index (in the partially composed production built so far) of
        the variable being replaced.

    Returns:
      The root of the newly built production tree.

    Raises:
      ValueError: If a composition index is out of range or does not refer to a
        variable token, if the mapped child index exceeds the number of
        children of the parent node, or if the targeted child slot is already
        occupied.
    """
    production_tree = cls.from_production(production_provenance.source)

    # Mapping of composition indices from `production_provenance` (which
    # represent token indices in the rhs of a partially composed production) to
    # the corresponding ProductionTree node and child index of that node.
    # Indices of non-variable tokens are mapped to `None`. At first there is
    # only one ProductionTree we are dealing with, but as we proceed
    # through the compositions, there could be multiple different ProductionTree
    # nodes with free variables at any given time.
    composition_index_mapping: List[Optional[Tuple[ProductionTree, int]]] = (
        _get_composition_index_mapping_for_single_production(production_tree))

    for composition_number, (production, index) in enumerate(
        production_provenance.compositions):
      # Determine where to apply the composition in the production tree.
      # Negative indices are rejected too: they would otherwise be silently
      # interpreted relative to the end of the list and then corrupt the
      # mapping when it is re-sliced below.
      if not 0 <= index < len(composition_index_mapping):
        raise ValueError(
            f'Composition index not found in composition_index_mapping:\n'
            f'  composition number={composition_number}\n'
            f'  composition=({production}, {index})\n'
            f'  len(composition_index_mapping)='
            f'{len(composition_index_mapping)}\n'
            f'  composition_index_mapping={composition_index_mapping}')
      mapping_entry = composition_index_mapping[index]
      if mapping_entry is None:
        raise ValueError(
            f'Composition index does not correspond to a variable token:\n'
            f'  composition number={composition_number}\n'
            f'  composition=({production}, {index})\n'
            f'  composition_index_mapping={composition_index_mapping}')
      parent_tree_node, child_index = mapping_entry
      if child_index >= len(parent_tree_node.children):
        raise ValueError(
            f'Child index in composition_index_mapping too big for number of '
            f'children in production:\n'
            f'  parent_tree_node={parent_tree_node}\n'
            f'  child_index={child_index}\n')

      # Applying the composition now simply means inserting the production
      # in the appropriate place in the tree.
      if parent_tree_node.children[child_index] is not None:
        raise ValueError(
            f'Attempting to insert a production into the ProductionTree in a '
            f'location where a child production already exists:\n'
            f'  parent_tree_node={parent_tree_node}\n'
            f'  child_index={child_index}\n'
            f'  Existing child={parent_tree_node.children[child_index]}\n')
      new_tree_node = cls.from_production(production)
      parent_tree_node.children[child_index] = new_tree_node

      # Finally, adjust the composition index mapping to account for the one
      # variable that was eliminated via the composition, as well as any new
      # variables or other tokens that have been introduced.
      new_mapping = list(composition_index_mapping[:index])
      new_mapping.extend(
          _get_composition_index_mapping_for_single_production(new_tree_node))
      new_mapping.extend(composition_index_mapping[index + 1:])
      composition_index_mapping = new_mapping

    return production_tree

  def get_child(self, i):
    """Returns the ith child tree node, or None if it doesn't exist."""
    return self.children[i] if i < len(self.children) else None

  def get_composed_production(self):
    """Returns the production formed by composing all the productions.

    Raises:
      ValueError: If `children` has more entries than the production has
        variables.
    """
    # We compose the children in reverse order so that we don't need to worry
    # about adjusting the index after each composition to account for the
    # resulting increase or decrease in the number of variables.
    production = self.production
    token_index_by_variable_index = _get_token_index_by_variable_index(
        production)
    for variable_index, child_tree in reversed(list(enumerate(self.children))):
      if variable_index >= len(token_index_by_variable_index):
        raise ValueError(
            f'Unable to compose production at variable index {variable_index}, '
            f'as there are only {len(token_index_by_variable_index)} variables '
            f'in parent production: {production}')
      composition_index = token_index_by_variable_index[variable_index]
      if child_tree:
        production = production_composition.compose(
            production, child_tree.get_composed_production(), composition_index)  # pytype: disable=attribute-error

    return production

  def get_input_string(self):
    """Returns the input string corresponding to the composed production's rhs.

    The format of the returned string is the same as that used in the
    interpretation rule format. May include variable names, if the composed
    production contains free variables.
    """
    return rule_conversion.interpretation_rule_input_string_from_production(
        self.get_composed_production())

  def get_variable_substitutions_by_production(
      self,
      variable_substitutions_by_production
  ):
    """Updates variable_substitutions_by_production recursively.

    In typical usage, a `variable_substitutions_by_production` data structure
    will be built for a context as a whole by calling this method iteratively
    with the same `variable_substitutions_by_production` but on different
    ProductionTrees, one for each context example. Initially the
    `variable_substitutions_by_production` will be empty, but by the end it will
    show for each source production all of the different variable
    substitutions that were illustrated for that production across all of the
    examples in the context.

    Args:
      variable_substitutions_by_production: This will be updated to include all
        the variable substitutions found for any production/variable combination
        in this production tree. If it already has contents, then the new
        contents will be merged with the old ones. The structure is a mapping
        from each production to a mapping of its variable substitutions. The
        variable substitution mapping itself is a mapping from each of the
        production's variable names (e.g., 'x1') to a set of input phrases that
        the variable was substituted for in any of the examples that this method
        was called on.
    """
    # E.g., ['x1', 'and', 'x2']
    input_tokens = (
        rule_conversion.interpretation_rule_input_tokens_from_production(
            self.production))
    # E.g., [0, 2]
    token_index_by_variable_index = _get_token_index_by_variable_index(
        self.production)
    # E.g., {'x1': {...}, 'x2': {...}}
    variable_substitution_map = variable_substitutions_by_production.setdefault(
        self.production, {})

    # Process each of the variables in the current production.
    for variable_index, token_index in enumerate(token_index_by_variable_index):
      child_tree = self.get_child(variable_index)
      variable_name = input_tokens[token_index]
      substitutions = variable_substitution_map.setdefault(variable_name, set())
      if child_tree:
        # The variable was substituted: Update the substitutions for the current
        # production and then for the child production tree recursively.
        substitutions.add(child_tree.get_input_string())
        child_tree.get_variable_substitutions_by_production(
            variable_substitutions_by_production)
      else:
        # The variable was not substituted: Simply treat the variable name
        # itself as the substituted string.
        substitutions.add(variable_name)

  def get_variable_substitutions_by_rule(
      self, variable_substitutions_by_rule,
      rule_format):
    """Updates variable_substitutions_by_rule recursively.

    Like `get_variable_substitutions_by_production`, except that in the
    `variable_substitutions_by_rule` data structure, the keys of the outer
    mapping are rule strings of the specified rule format, rather than raw
    Productions. Productions that don't have a corresponding rule in the
    specified format will be omitted. (That means, for example, that when using
    interpretation rule format, PassThroughRules will be omitted.) This is the
    format that will actually be stored in ExampleSetMetadata.

    Args:
      variable_substitutions_by_rule: This will be updated to include all the
        variable substitutions found for any rule/variable combination in this
        production tree.  If it already has contents, then the new contents will
        be merged with the old ones. The structure is a mapping from each rule
        to a mapping of its variable substitutions. The variable substitution
        mapping itself is a mapping from each of the rule's variable names
        (e.g., 'x1') to a set of input phrases that the variable was substituted
        for in any of the examples that this method was called on. E.g., for the
        rule '[x1 twice] = …', if it is used in an example <'turn left twice
        after jump', …>, then the variable substitution for x1 would be 'turn
        left'. If the example were <'x1 twice after jump', …>, then the
        variable substitution for x1 would be simply 'x1'.
      rule_format: The rule format in which to represent the rules.
    """
    variable_substitutions_by_production = {}
    self.get_variable_substitutions_by_production(
        variable_substitutions_by_production)

    # Here we convert each production to the requested rule format, and then
    # merge the production's variable substitution map into any existing
    # substitutions that we may already have for that rule.
    for production, new_variable_substitution_map in (
        variable_substitutions_by_production.items()):
      rule = rule_conversion.rule_from_production(production, rule_format)
      if not rule:
        # Skip over pass-through rules.
        continue
      variable_substitution_map = variable_substitutions_by_rule.setdefault(
          rule, {})  # pytype: disable=container-type-mismatch
      for variable_name, new_substitutions in (
          new_variable_substitution_map.items()):
        substitutions = variable_substitution_map.setdefault(
            variable_name, set())
        substitutions.update(new_substitutions)

  def _get_outer_substitutions_by_production_for_children(
      self, outer_substitutions_by_production
  ):
    """Updates outer_substitutions_by_production recursively for descendants.

    For every non-None child, the outer substitution recorded is the input
    string of the current node with that child's phrase replaced by
    OUTER_SUBSTITUTION_PLACEHOLDER. The tree is modified temporarily while that
    string is computed, and is restored before this method returns or raises.

    Args:
      outer_substitutions_by_production: Mapping from production to the set of
        its outer substitutions; updated in place.
    """
    for child_index, child_tree in reversed(list(enumerate(self.children))):
      if not child_tree:
        continue

      # Here we temporarily replace the child with a trivial ProductionTree
      # containing OUTER_SUBSTITUTION_PLACEHOLDER as its only rhs terminal, so
      # that the placeholder shows up in place of the child's phrase when the
      # current node's input string is computed.
      placeholder_child_production = nltk.grammar.Production(
          lhs=child_tree.production.lhs(), rhs=[OUTER_SUBSTITUTION_PLACEHOLDER])
      placeholder_child_tree = ProductionTree(placeholder_child_production)
      self.children[child_index] = placeholder_child_tree
      try:
        # Now we can use the input string of the current production (i.e.,
        # parent production) as the outer substitution.
        outer_substitution = self.get_input_string()
      finally:
        # Revert the temporary change even if computing the string failed, so
        # that the tree is never left with a placeholder child.
        self.children[child_index] = child_tree

      outer_substitutions = outer_substitutions_by_production.setdefault(
          child_tree.production, set())  # pytype: disable=attribute-error
      outer_substitutions.add(outer_substitution)

      # Continue recursively with the real child.
      child_tree._get_outer_substitutions_by_production_for_children(
          outer_substitutions_by_production)

      # pytype: enable=attribute-error

  def get_outer_substitutions_by_production(
      self, outer_substitutions_by_production
  ):
    """Updates outer_substitutions_by_production for self and descendants.

    In typical usage, an `outer_substitutions_by_production` data structure will
    be built for a context as a whole by calling this method iteratively with
    the same `outer_substitutions_by_production` but on different
    ProductionTrees, one for each context example. Initially the
    `outer_substitutions_by_production` will be empty, but by the end it will
    show for each source production all of the different outer substitutions
    that were illustrated for that production across all of the examples in the
    context.

    Args:
      outer_substitutions_by_production: This will be updated to include all the
        outer substitutions found for any production in this production tree. If
        it already has contents, then the new contents will be merged with the
        old ones. The structure is a mapping from each production to a set of
        its "outer substitutions". We define the "outer substitution" to be the
        string corresponding to the node immediately above in the production
        tree of the example as a whole (skipping over pass-through rule nodes),
        with the substring corresponding to the current production replaced with
        '__'. If the given production is the topmost non-pass-through production
        in the production tree of the example as a whole, then we define the
        outer substitution to be simply '__'.
    """
    # The production at the root of this tree has nothing above it, so its
    # outer substitution is the bare placeholder.
    outer_substitutions = outer_substitutions_by_production.setdefault(
        self.production, set())
    outer_substitutions.add(OUTER_SUBSTITUTION_PLACEHOLDER)
    self._get_outer_substitutions_by_production_for_children(
        outer_substitutions_by_production)

  def get_outer_substitutions_by_rule(
      self, outer_substitutions_by_rule,
      rule_format):
    """Updates outer_substitutions_by_rule recursively.

    Like `get_outer_substitutions_by_production`, except that in the
    `outer_substitutions_by_rule` data structure, the keys are rule strings of
    the specified rule format, rather than raw Productions. Productions that
    don't have a corresponding rule in the specified format will be omitted.
    (That means, for example, that when using interpretation rule format,
    PassThroughRules will be omitted.) This is the format that will actually be
    stored in ExampleSetMetadata.

    Args:
      outer_substitutions_by_rule: This will be updated to include all the outer
        substitutions found for any rule in this production tree. If it already
        has contents, then the new contents will be merged with the old ones.
        The structure is a mapping from each rule to a set of its "outer
        substitutions". We define the "outer substitution" to be the string
        corresponding to the node immediately above in the rule tree of the
        example as a whole (skipping over pass-through rule nodes), with the
        substring corresponding to the current rule replaced with '__'. If the
        given rule is the topmost non-pass-through rule in the rule tree of the
        example as a whole, then we define the outer substitution to be simply
        '__'. E.g., in the example 'turn left twice after jump', the outer
        substitution for '[x1 twice] = …' would be '__ after jump', while the
        outer substitution for '[turn] = …' would be '__ left', and the outer
        substitution for '[x1 after x2] = …' would be '__'.
      rule_format: The rule format in which to represent the rules.
    """
    outer_substitutions_by_production = {}
    self.get_outer_substitutions_by_production(
        outer_substitutions_by_production)

    # Here we convert each production to the requested rule format, and then
    # merge the production's outer substitution map into any existing
    # substitutions that we may already have for that rule.
    for production, new_outer_substitutions in (
        outer_substitutions_by_production.items()):
      rule = rule_conversion.rule_from_production(production, rule_format)
      if not rule:
        # Skip over pass-through rules.
        continue
      outer_substitutions = outer_substitutions_by_rule.setdefault(rule, set())  # pytype: disable=container-type-mismatch
      outer_substitutions.update(new_outer_substitutions)

  @functools.cached_property
  def graph(self):
    """Returns a graph representation of the production tree."""
    graph = nx.DiGraph()

    # Since the same production could appear multiple times in the production
    # tree, we use the unique path (the tuple of indices) from the root as
    # graph nodes.  The empty tuple is the path to the root.
    graph.add_node(())
    for i, child in enumerate(self.children):
      if child is not None:
        child_graph = child.graph
        for child_u in child_graph.nodes:
          # These nodes are paths from the child graph's root, so we prepend the
          # child index to form the path from the parent's root.
          graph.add_node((i,) + child_u)

        graph.add_edge((), (i,))
        for child_u, child_v in child_graph.edges:
          graph.add_edge((i,) + child_u, (i,) + child_v)

    return graph

  def production_from_path(self, path):
    """Returns the production of the node reached by following `path`.

    Args:
      path: Sequence of child indices leading from this node down to the
        desired node, in the same form as the nodes of `self.graph`. An empty
        path refers to this node itself.
    """
    if not path:
      return self.production
    child = self.children[path[0]]
    return child.production_from_path(path[1:])

  def subtree_from_subgraph_and_root(self, subgraph,
                                     root):
    """Returns the production subtree of the subgraph at the specified root.

    Args:
      subgraph: A subgraph of self.graph.
      root: A node in the subgraph specifying the root of the desired subtree.

    Returns:
      A new ProductionTree whose root holds the production found at `root` and
      whose children are populated only for the successors of `root` that are
      present in `subgraph` (recursively); all other child slots are None.
    """
    subtree = ProductionTree.from_production(self.production_from_path(root))

    # The nodes in the graph are full paths from self's root.
    for child_path in subgraph.successors(root):
      child_index = child_path[-1]
      subtree.children[child_index] = self.subtree_from_subgraph_and_root(
          subgraph, child_path)

    return subtree

  def get_subtrees_containing_root(self, size):
    """Returns production subtrees containing the root of self.

    Each element in the return list is a production tree of the specified size
    whose root is the same as self.production.

    Args:
      size: The size of the returned subtrees.

    Returns:
      List of ProductionTrees; empty if `size` is 0 or if no connected subtree
      of that size contains the root.
    """
    if not size:
      return []

    if size == 1:
      return [ProductionTree.from_production(self.production)]

    root = ()
    # Reference: https://stackoverflow.com/a/65870702
    # Focus on only nodes close enough to root.
    ego_graph = nx.generators.ego_graph(
        self.graph, root, radius=size - 1, center=False)
    result = []
    # This is not very efficient, but the code is easy to understand, and in
    # our use case the graphs are small.
    for nodes in itertools.combinations(ego_graph, size - 1):
      nodes_including_root = (root,) + nodes
      subgraph = self.graph.subgraph(nodes_including_root)
      if nx.is_weakly_connected(subgraph):
        result.append(self.subtree_from_subgraph_and_root(subgraph, root))

    return result

  def get_all_subtrees(self, size):
    """Returns all production subtrees of the specified size.

    This collects the subtrees of that size rooted at this node, followed by
    those found recursively under each non-None child.

    Args:
      size: The size of the returned subtrees.
    """
    if not size:
      return []

    result = self.get_subtrees_containing_root(size)
    for child in self.children:
      if child is not None:
        result.extend(child.get_all_subtrees(size))

    return result


def get_effective_num_outer_substitutions(outer_substitutions):
  """Returns the effective # of outer substitutions for the inductive bias.

  A rule that was used at least once as the top rule of a rule application tree
  has the bare placeholder among its outer substitutions. That is taken as
  evidence that the rule is valid in every possible outer context, and so
  counts as an infinite number of substitutions. In all other cases the
  effective number is simply the size of the set.

  Args:
    outer_substitutions: The set of outer substitutions of the rule of interest.
  """
  if OUTER_SUBSTITUTION_PLACEHOLDER not in outer_substitutions:
    return len(outer_substitutions)
  return INFINITE_SUBSTITUTIONS


def _get_effective_num_variable_substitutions(
    variable_substitutions):
  """Returns the effective # of substitutions for a single variable.

  If any of the substitutions is itself a variable name (an 'x' followed by
  zero or more digits), the result is INFINITE_SUBSTITUTIONS. Otherwise it is
  the number of substitutions.

  Args:
    variable_substitutions: The set of substitution strings recorded for one
      variable.
  """
  for substitution in variable_substitutions:
    if re.fullmatch(r'x[0-9]*', substitution):
      return INFINITE_SUBSTITUTIONS
  return len(variable_substitutions)


def get_effective_min_num_variable_substitutions(
    substitutions_map):
  """Returns the effective min # of substitutions for use in the inductive bias.

  The result is the smallest effective substitution count over all variables of
  the rule. A variable's effective count is normally the size of its set of
  substitutions, but a variable that appears at least once in unsubstituted
  form counts as having an infinite number of substitutions (i.e., as evidence
  that the rule holds for all possible substitutions of that variable). A rule
  with no variables at all also yields INFINITE_SUBSTITUTIONS.

  Args:
    substitutions_map: Mapping of variable name to the set of variable
      substitutions for all variables in the rule of interest.
  """
  effective_counts = (
      _get_effective_num_variable_substitutions(substitutions)
      for substitutions in substitutions_map.values())
  return min(effective_counts, default=INFINITE_SUBSTITUTIONS)
