# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for composing FeatureGrammar productions.

The functions in this module apply only to FeatureGrammar productions whose
"sem" features follow the same conventions as those generated by
grammar_generation.  See comments in the grammar_representation module for a
full list of restrictions on a valid FeatureGrammar considered in Conceptual
SCAN.

Two productions, parent and other_parent, are composable if the LHS of
other_parent has the same grammar category as one of the nonterminals in
parent's RHS.  The composition at that item is then obtained by:
(1) Replace that item in parent's RHS with other_parent's RHS.
(2) Substitute the occurrences of that item's variable in parent's LHS "sem"
    feature with the "sem" feature of other_parent's LHS.

Example:
Let:
parent       = C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]
other_parent = S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'

Composing at the first S-nonterminal gives:
C[sem=(?x2+?x1+?x1)] -> V[sem=?x1] 'twice' 'after' S[sem=?x2]

Composing at the second S-nonterminal gives:
C[sem=(?x2+?x2+?x1)] -> S[sem=?x1] 'after' V[sem=?x2] 'twice'

"""

import collections
import dataclasses
import functools
import itertools
from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple

import nltk

from conceptual_learning.cscan import nltk_utils

# Type alias for a mutable mapping from a production to the single provenance
# recorded for it (ProductionProvenanceDict below is one such mapping).  The
# value type is a string forward reference because ProductionProvenance is
# defined later in this module.
ProductionProvenanceMapping = MutableMapping[nltk.grammar.Production,
                                             'ProductionProvenance']


class ProductionProvenanceDict(collections.UserDict):
  """Maps each production to one provenance that builds it.

  Every assignment is validated: the provenance stored under a production must
  rebuild exactly that production via its get_production method.
  """

  def __setitem__(self, key, item):
    rebuilt = item.get_production()
    if key != rebuilt:
      raise ValueError(
          f'Provenance and production mismatch.  Production: {key}, '
          f'provenance has {rebuilt}')
    self.data[key] = item

  def production_is_source_production(self, production):
    """Returns whether `production` is recorded with the trivial provenance.

    The trivial provenance has the production itself as source and no
    compositions.
    """
    if production not in self.data:
      return False
    return self.data[production] == ProductionProvenance(source=production)

  def get_source_productions(self):
    """Returns the recorded productions that are source productions."""
    is_source = self.production_is_source_production
    return [candidate for candidate in self.data if is_source(candidate)]


class ProductionProvenancesDict(ProductionProvenanceDict):
  """Maps each production to the set of all provenances recorded for it.

  Warning: Unlike ProductionProvenanceDict above, this class doesn't strictly
    enforce that the added provenances lead to its corresponding production key,
    as that enforcement check was found to slow down calculation of the
    consistency metric by 60%.
  """

  def __setitem__(self, key, item):
    """Replaces the provenances of `key` with those in the iterable `item`."""
    # Start from a fresh set so earlier provenances are discarded, not merged.
    self.data[key] = set()
    for provenance in item:
      self.add_provenance(key, provenance)

  def add_provenance(self, key, provenance):
    """Adds `provenance` to the set recorded for `key`.

    The set is created if `key` has none yet.
    """
    self.data.setdefault(key, set()).add(provenance)

  def production_is_source_production(self, production):
    """Returns whether one of the provenances of `production` is itself."""
    if production not in self.data:
      return False
    # A source production has the trivial provenance: itself, no compositions.
    trivial_provenance = ProductionProvenance(source=production)
    return any(candidate == trivial_provenance
               for candidate in self.data[production])


def _get_variables(production):
  """Returns the "sem" value of each RHS nonterminal, in RHS order.

  Terminals are skipped and repeated values are kept.
  """
  variables = []
  for symbol in production.rhs():
    if not isinstance(symbol, nltk.grammar.Nonterminal):
      continue  # Terminals carry no "sem" feature.
    variables.append(symbol['sem'])
  return variables


def num_variables_from_production(production):
  """Counts the distinct unbound variables of a production.

  Only the RHS is inspected: the "sem" value of every RHS nonterminal is
  collected, and values occurring more than once are counted once.

  Args:
    production: The nltk.grammar.Production whose variables are counted.

  Returns:
    The number of distinct "sem" values among the RHS nonterminals.
  """
  distinct_variables = frozenset(_get_variables(production))
  return len(distinct_variables)


# This function is called at least three times for every composition computed
# by _calculate_composition (and more often when compose tracks provenance),
# frequently with the same arguments, so we memoize it.  Basic profiling runs
# showed a reduction of about 35% of time spent on generating one example group.
@functools.lru_cache(maxsize=None)
def _normalize_variable_names(production,
                              token='x'):
  """Returns a new production with normalized variable names.

  The returned production has the same phrase structure and semantics as the
  provided production, but with variable names normalized to use the same token
  and increasing ordering of (1-based) index by their appearing order on the
  RHS.  A variable that appears on the RHS more than once is numbered by its
  first appearance.

  Example:
  A[sem=(?u1+?u+?v3)] -> B[sem=?u] 'and' C[sem=?v3] 'after' D[sem=?u1]
  becomes
  A[sem=(?x3+?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2] 'after' D[sem=?x3]

  Results are memoized in an unbounded cache, so production must be hashable.

  Args:
    production: The nltk.grammar.Production to be normalized.
    token: The choice of normalized variable name token.

  Returns:
    The normalized nltk.grammar.Production.

  Raises:
    ValueError: If the provided variable name token already appears in the
      production's RHS variables.  This would cause infinite recursion when
      calling substitute_bindings.
  """
  # De-duplicate while keeping first-appearance order.  Zipping the raw list
  # against the new names would let a repeated variable's later entry overwrite
  # its earlier binding, leaving gaps in the numbering (e.g. ?x2, ?x3 but no
  # ?x1).
  variables = list(dict.fromkeys(_get_variables(production)))
  if any(variable.name.startswith(f'?{token}') for variable in variables):
    raise ValueError(
        f'Variable name {token} already exists in production {production}.')

  new_variables = [
      nltk.Variable(nltk_utils.add_variable_prefix(f'{token}{i}'))
      for i in range(1, len(variables) + 1)
  ]
  bindings = dict(zip(variables, new_variables))

  # Only nonterminals carry variables; terminals are kept unchanged.
  new_rhs = tuple(
      item.substitute_bindings(bindings)
      if isinstance(item, nltk.grammar.Nonterminal) else item
      for item in production.rhs())
  new_lhs = production.lhs().substitute_bindings(bindings)
  return nltk.grammar.Production(lhs=new_lhs, rhs=new_rhs)


def normalize_semantics(
    production):
  """Returns a new production whose LHS "sem" feature is normalized.

  A normalized "sem" feature is any of the following:
  - A string.
  - A Variable.
  - A FeatureValueConcat of Variables and strings.
  - A FeatureValueTuple of strings.

  This is to keep things consistent with how Conceptual SCAN expresses the "sem"
  features using FeatureValueConcat.  This is no longer the case when
  productions are composed.  For example we could end up with "sem" features
  that have a mixture of FeatureValueTuple and FeatureValueConcat, such as:
  A[sem=?x1+(WALK, WALK)], which has a nested FeatureValueTuple (WALK, WALK).

  This function would normalize it to use FeatureValueConcat as:
  A[sem=?x1+WALK+WALK], making it consistent with Conceptual SCAN grammars.

  On the other hand, when all the items in the "sem" feature are strings without
  any variable, we would still get a FeatureValueTuple.  This is caused by nltk
  overriding the type of created instance when there are no variables.

  Example: nltk.featstruct.FeatureValueConcat(values=['a', 'b']) has type
  nltk.featstruct.FeatureValueTuple and is equal to
  nltk.featstruct.FeatureValueTuple(['a', 'b'])

  Empty strings are dropped, both at the top level and inside nested
  FeatureValueTuples.  If nothing remains the "sem" feature becomes '', and a
  single remaining value is used on its own.

  Args:
    production: The nltk.grammar.Production to be normalized.

  Returns:
    A new nltk.grammar.Production with the same RHS and grammar category, whose
    LHS carries only the normalized "sem" feature and the category.

  Raises:
    ValueError: If the "sem" feature of the LHS is not a string, Variable,
      FeatureValueTuple, or FeatureValueConcat; or if any item in the "sem"
      feature of the LHS is not a string, Variable, or FeatureValueTuple.
  """
  sem = production.lhs()['sem']
  if isinstance(sem, (str, nltk.Variable)):
    new_sem = sem
  elif isinstance(
      sem,
      (nltk.featstruct.FeatureValueTuple, nltk.featstruct.FeatureValueConcat)):
    values = []
    for item in sem:
      if isinstance(item, (str, nltk.Variable)):
        # We skip empty strings to avoid production strings that cannot be
        # successfully loaded back to productions such as the following:
        # "A[sem=(, )] -> 'nothing'"
        # "A[sem=(, NONEMPTY)] -> 'nothing'"
        # "A[sem=(, ?x1)] -> 'nothing'"
        if not item:
          continue
        values.append(item)
      elif isinstance(item, nltk.featstruct.FeatureValueTuple):
        # Flatten the nested tuple, dropping empty strings for the same reason
        # as above.
        values.extend(value for value in item if value)
      else:
        raise ValueError(f'Item in LHS semantic feature should be a string, a '
                         f'Variable, or a FeatureValueTuple, but got {item} of '
                         f'type {type(item)} in {production}.')
    if not values:
      new_sem = ''
    elif len(values) == 1:
      new_sem = values[0]
    else:
      new_sem = nltk.featstruct.FeatureValueConcat(values=values)

  else:
    raise ValueError(f'Semantic feature of LHS must be a string, Variable, '
                     f'FeatureValueTuple, or FeatureValueConcat, but got {sem} '
                     f'of type {type(sem)} in {production}.')

  new_lhs = nltk.grammar.FeatStructNonterminal({
      'sem': new_sem,
      nltk.grammar.TYPE: production.lhs()[nltk.grammar.TYPE]
  })
  return nltk.Production(lhs=new_lhs, rhs=production.rhs())


def composable_indices(parent,
                       other_parent):
  """Returns the positions in parent's RHS where other_parent can be composed.

  A position qualifies when the RHS item there is a nonterminal whose grammar
  category equals the grammar category of other_parent's LHS.
  """
  wanted_category = other_parent.lhs()[nltk.grammar.TYPE]
  return [
      position for position, symbol in enumerate(parent.rhs())
      if isinstance(symbol, nltk.grammar.Nonterminal) and
      symbol[nltk.grammar.TYPE] == wanted_category
  ]


@functools.lru_cache(maxsize=None)
def _calculate_composition(parent,
                           other_parent,
                           index):
  """Returns the composition of the productions at the provided index.

  A production parent can be composed with another production other_parent at
  index i if the i-th item in parent's RHS is a nonterminal with the same
  grammar category as other_parent's LHS.  In the result, that RHS item is
  replaced by other_parent's RHS. The item's variable in parent's LHS "sem"
  feature is replaced by other_parent's LHS "sem" feature. The result's
  variable names and LHS "sem" feature are then normalized.

  Args:
    parent: An nltk.grammar.Production.
    other_parent: An nltk.grammar.Production.
    index: The index of parent's RHS item to be expanded with other_parent.
      Must satisfy 0 <= index < len(parent.rhs()).

  Returns:
    The composed nltk.grammar.Production.

  Raises:
    ValueError: If the productions cannot be composed at the provided index,
      including when index is outside the range of parent's RHS.
  """
  parent_rhs = parent.rhs()
  # Check the range explicitly.  A negative index would pass the category
  # check (Python sequence indexing wraps around) but would never be spliced
  # into the RHS below, while parent's LHS would still be substituted, yielding
  # an inconsistent production.  An index past the end would raise IndexError
  # instead of the documented ValueError.
  if (not 0 <= index < len(parent_rhs) or
      not isinstance(parent_rhs[index], nltk.grammar.Nonterminal) or
      (parent_rhs[index][nltk.grammar.TYPE] !=
       other_parent.lhs()[nltk.grammar.TYPE])):
    raise ValueError(f'Productions not composable at index {index}: '
                     f'parent={parent}, other_parent={other_parent}')

  # Internally we use different variable tokens for the two parents so the call
  # to substitute_bindings does not get into an infinite recursion.
  parent = _normalize_variable_names(parent, token='X')
  other_parent = _normalize_variable_names(other_parent, token='Y')

  # Replace the item at index with the whole RHS of other_parent.
  parent_rhs = parent.rhs()
  new_rhs = (
      tuple(parent_rhs[:index]) + tuple(other_parent.rhs()) +
      tuple(parent_rhs[index + 1:]))

  # Substitute the replaced item's variable in parent's LHS "sem" feature with
  # other_parent's LHS "sem" feature.
  bindings = {parent_rhs[index]['sem']: other_parent.lhs()['sem']}
  new_lhs = parent.lhs().substitute_bindings(bindings)
  new_production = nltk.grammar.Production(lhs=new_lhs, rhs=new_rhs)
  return normalize_semantics(_normalize_variable_names(new_production))


def compose(
    parent,
    other_parent,
    index,
    provenance_by_production=None,
    provenances_by_production=None,
):
  """Composes two productions at an index, optionally recording provenance.

  Args:
    parent: An nltk.grammar.Production.
    other_parent: An nltk.grammar.Production.
    index: The index of parent's RHS item to be expanded with other_parent.
    provenance_by_production: Optional mapping.  When given, the (normalized)
      parents are registered as source productions if they are missing. A
      representative provenance of the composed production is stored only if
      none is recorded yet.
    provenances_by_production: Optional mapping offering add_provenance (e.g.
      a ProductionProvenancesDict).  When given, the (normalized) parents are
      registered as source productions if they are missing. Every splice of a
      recorded parent provenance with a recorded other_parent provenance is
      added to the composed production's provenances.

  Returns:
    The composed production, with normalized variable names and semantics.
  """
  composed = _calculate_composition(parent, other_parent, index)

  track_one = provenance_by_production is not None
  track_all = provenances_by_production is not None
  if not track_one and not track_all:
    return composed

  # Provenance records are keyed on normalized productions.  Renaming to the
  # internal 'X'/'Y' tokens first lets productions that already use the
  # default 'x' token be normalized without tripping its token check.
  normalized_parent = _normalize_variable_names(
      _normalize_variable_names(parent, token='X'))
  normalized_other = _normalize_variable_names(
      _normalize_variable_names(other_parent, token='Y'))

  if track_one:
    # Provenances are not unique; only the first one recorded is kept.
    parent_provenance = provenance_by_production.setdefault(
        normalized_parent, ProductionProvenance(source=normalized_parent))
    other_provenance = provenance_by_production.setdefault(
        normalized_other, ProductionProvenance(source=normalized_other))
    if composed not in provenance_by_production:
      provenance_by_production[composed] = ProductionProvenance.splice(
          parent_provenance, other_provenance, index)

  if track_all:
    parent_provenances = provenances_by_production.setdefault(
        normalized_parent, [ProductionProvenance(source=normalized_parent)])
    other_provenances = provenances_by_production.setdefault(
        normalized_other, [ProductionProvenance(source=normalized_other)])
    for parent_provenance, other_provenance in itertools.product(
        parent_provenances, other_provenances):
      provenances_by_production.add_provenance(
          composed,
          ProductionProvenance.splice(parent_provenance, other_provenance,
                                      index))

  return composed


@dataclasses.dataclass(frozen=True)
class ProductionProvenance:
  """The source production and compositions a production is built from.

  For example, the production p = "C -> 'walk' 'and' 'jump'" can be built from
  s = "C -> S 'and' S", t = "S -> 'walk'", and u = "S -> 'jump'" starting with
  the source production s and composing with t and u, then the provenance of the
  production p is ProductionProvenance(source=s, compositions=((t, 0), (u, 2))),
  in the sense that if we carry out the following compositions we would get p:
  compose(compose(s, t, 0), u, 2).

  Note that the provenance of a production is not unique in general: the same
  production p has another provenance by switching the order of elements in
  the compositions field, namely
  ProductionProvenance(source=s, compositions=((u, 2), (t, 0))).

  Instances are frozen and hashable (they are stored in sets, for example), so
  compositions should be a tuple of pairs as annotated, not a list.

  Attributes:
    source: The first parent production of the production.
    compositions: The sequence of (other_parent, index) pairs the production is
      built from.
  """
  source: nltk.grammar.Production
  compositions: Tuple[Tuple[nltk.grammar.Production, int],
                      ...] = dataclasses.field(default_factory=tuple)

  def get_production(self):
    """Returns the production formed by composing all the productions."""
    production = self.source
    for other_parent, index in self.compositions:
      production = compose(production, other_parent, index)

    return production

  def replace(
      self, old_production,
      new_production,
      provenance_by_production
  ):
    """Returns a new provenance with a production replaced.

    Every occurrence of old_production, either as the source or as an
    other_parent in compositions, is replaced with new_production, and the
    composition indices are kept.  The new provenance is also recorded in
    provenance_by_production under the production it builds.

    For example, if production_provenance is the provenance:
    ProductionProvenance(source=s, compositions=((t, 0), (u, 2))), then
    production_provenance.replace(t, t1, provenance_by_production) is the
    provenance:
    ProductionProvenance(source=s, compositions=((t1, 0), (u, 2))).

    Args:
      old_production: The production to be replaced.
      new_production: The production replacing old_production.
      provenance_by_production: The ProductionProvenanceMapping to record the
        new provenance in.

    Returns:
      The new ProductionProvenance.

    Raises:
      ValueError: If old_production and new_production do not have the same
        RHS, or (via compose) if the productions of the new provenance cannot
        be composed.
    """
    # We support replacing a production only if the RHSs are identical, since
    # the composition indices refer to positions in those RHSs.
    if old_production.rhs() != new_production.rhs():
      raise ValueError(f'Old production and new production do not have the '
                       f'same RHS.  Old production: {old_production}, new '
                       f'production: {new_production}.')

    def swap(production):
      return new_production if production == old_production else production

    provenance = dataclasses.replace(
        self,
        source=swap(self.source),
        compositions=tuple((swap(other_parent), index)
                           for other_parent, index in self.compositions))

    # We need to manually record provenance created by a replacement, since the
    # resulting production is not obtained by calling the compose function.
    production_with_provenance = provenance.get_production()
    provenance_by_production[production_with_provenance] = provenance

    return provenance

  def to_json(self):
    """Returns a JSON representation of this object."""
    return {
        # Productions are converted to strings, as nltk.grammar.Production is
        # not directly JSON serializable.
        'source': str(self.source),
        # The index is supposed to already be an integer, but we convert it
        # explicitly to an int anyway, as for some reason it sometimes gets
        # populated as an Int64, which is also not JSON serializable.
        'compositions':
            tuple((str(p), int(index)) for p, index in self.compositions)
    }

  @classmethod
  def from_json(cls, unstructured):
    """Returns the original object restored from its JSON representation.

    Args:
      unstructured: A mapping in the format produced by to_json, with a
        'source' production string and 'compositions' as a sequence of
        (production string, index) pairs.
    """
    return cls(
        source=nltk_utils.production_from_production_string(
            unstructured['source']),
        # Indices are coerced to int, mirroring to_json.
        compositions=tuple(
            (nltk_utils.production_from_production_string(p), int(index))
            for p, index in unstructured['compositions']))

  @classmethod
  def splice(cls, parent_provenance,
             other_parent_provenance,
             index):
    """Returns the provenance connecting two provenances.

    The purpose of this method is to calculate the provenance of the composition
    of two productions, parent and other_parent, if the parents' provenances are
    known.  This way the composition production's provenance can be expressed in
    terms of the same pieces as the provenances of the parents.

    Specifically, if c = compose(parent, other_parent, index), then
    splice(parent_provenance, other_parent_provenance, index) is a provenance
    of the composition c.

    Args:
      parent_provenance: The provenance of the parent production in a
        composition.
      other_parent_provenance: The provenance of the other_parent production in
        a composition.
      index: The index at which the parent and the other_parent are composed.

    Returns:
      The spliced provenance.

    Raises:
      ValueError: If the productions of parent_provenance and
        other_parent_provenance cannot be composed at the index.
    """
    production_with_parent_provenance = parent_provenance.get_production()
    if index not in composable_indices(production_with_parent_provenance,
                                       other_parent_provenance.source):
      raise ValueError(f'Production {production_with_parent_provenance} of '
                       f'parent not composable with source '
                       f'{other_parent_provenance.source} of other parent at '
                       f'index {index}.')

    source = parent_provenance.source
    compositions = list(parent_provenance.compositions)
    compositions.append((other_parent_provenance.source, index))
    # After composing with other_parent's source at index, other_parent's
    # partial RHS starts at position index of the combined RHS, so each of its
    # own composition indices is shifted by index.
    for other_parent, other_index in other_parent_provenance.compositions:
      compositions.append((other_parent, other_index + index))

    return cls(source, tuple(compositions))