# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for interacting with nltk objects as used in cSCAN.

This is a low-level library, at the bottom of the dependency hierarchy of the
cSCAN codebase. In order to avoid circular dependencies, it is important that
this library does not import any other cSCAN-related libraries.
"""

import functools
from typing import Optional, Sequence

import nltk


@functools.lru_cache(maxsize=None)
def _extract_lhs_tokens_cached(production):
  """Computes the LHS `sem` tokens of a non-None production, memoized.

  Returns a tuple rather than a list so that the memoized value is immutable.
  A list returned from an lru_cache-decorated function is the same object for
  every caller, so one caller mutating it would corrupt all later results.
  """
  sem = production.lhs()['sem']
  if isinstance(sem, str):
    return (sem,)
  if isinstance(sem, nltk.Variable):
    return (sem.name,)
  if isinstance(
      sem,
      (nltk.featstruct.FeatureValueConcat, nltk.featstruct.FeatureValueTuple)):
    items = []
    for term in sem:
      if isinstance(term, str):
        items.append(term)
      elif isinstance(term, nltk.Variable):
        items.append(term.name)
      else:
        # Report the offending component itself, not only its container.
        raise ValueError(
            f'Failed to extract item {term} of type {type(term)} from {sem}.')
    return tuple(items)
  raise ValueError(f'Failed to extract item from {sem} of type {type(sem)}.')


def extract_lhs_tokens(
    production: 'Optional[nltk.grammar.Production]') -> Sequence[str]:
  """Returns strings representing items from the LHS of the production.

  Items that are strings (terminals) are collected as-is, while nltk.Variables
  have their names extracted. The extraction is memoized per production, but
  every call returns a fresh list, so callers may mutate the result without
  affecting later calls.

  Args:
    production: A FeatureGrammar production with `sem` features, or None.

  Returns:
    A list of token strings, or an empty tuple if `production` is None.

  Raises:
    ValueError: If the feature value of sem on LHS is not a string, an
      nltk.Variable, an nltk.featstruct.FeatureValueConcat, or an
      nltk.featstruct.FeatureValueTuple, or if there is a component in the
      feature value that is not a string or a Variable.
  """
  # Examples generated by dataset_generation.py have their metadata.production
  # populated with a proper production that is not None, but we include this
  # here just so that unit tests do not have to have dummy productions in test
  # cases' example metadata.
  if production is None:
    return ()
  return list(_extract_lhs_tokens_cached(production))


@functools.lru_cache(maxsize=None)
def _extract_rhs_tokens_cached(production):
  """Computes the RHS tokens of a non-None production, memoized.

  Returns a tuple so that the memoized value cannot be mutated by callers.
  """
  items = []
  for term in production.rhs():
    if isinstance(term, nltk.grammar.FeatStructNonterminal):
      sem = term['sem']
      if not isinstance(sem, nltk.Variable):
        raise ValueError(
            'Nonterminal term on the RHS should have only nltk.Variable as '
            f'semantics, but got {term}.')
      items.append(sem.name)
    else:
      # Anything that is not a feature-structure nonterminal (e.g. a terminal
      # string) is passed through unchanged.
      items.append(term)
  return tuple(items)


def extract_rhs_tokens(
    production: 'Optional[nltk.grammar.Production]') -> Sequence[str]:
  """Returns strings representing items from the RHS of the production.

  Items that are strings (terminals) are collected as-is, while nonterminals
  contribute the name of the nltk.Variable in their `sem` feature. The
  extraction is memoized per production, but every call returns a fresh list,
  so callers may mutate the result without affecting later calls.

  Args:
    production: A FeatureGrammar production with `sem` features, or None.

  Returns:
    A list of tokens, or an empty tuple if `production` is None.

  Raises:
    ValueError: If a feature-structure nonterminal on the RHS has a `sem`
      feature value that is not an nltk.Variable.
  """
  # Examples generated by dataset_generation.py have their metadata.production
  # populated with a proper production that is not None, but we include this
  # here just so that unit tests do not have to have dummy productions in test
  # cases' example metadata.
  if production is None:
    return ()
  return list(_extract_rhs_tokens_cached(production))


def production_from_production_string(
    production_string: str) -> 'nltk.grammar.Production':
  """Parses a production back out of its string representation.

  The text is parsed as a whole FeatureGrammar and that grammar's first
  production is returned; any further productions in the text are ignored.

  Args:
    production_string: Text of a production, presumably as produced by `str`
      on an nltk production.

  Returns:
    The first production parsed from `production_string`.
  """
  parsed_grammar = nltk.grammar.FeatureGrammar.fromstring(production_string)
  parsed_productions = parsed_grammar.productions()
  return parsed_productions[0]


def is_primitive_mapping(production) -> bool:
  """Tells whether the production is a PrimitiveMapping's rule string.

  A production qualifies when no item on its RHS is an nltk Nonterminal
  (an empty RHS therefore also qualifies).

  Args:
    production: An nltk.grammar.Production.

  Returns:
    True if the production is a PrimitiveMapping.
  """
  return not any(
      isinstance(rhs_item, nltk.grammar.Nonterminal)
      for rhs_item in production.rhs())


def is_pass_through_rule(production) -> bool:
  """Tells whether the production is a PassThroughRule's rule string.

  A production qualifies when its RHS consists of exactly one item and that
  item is an nltk Nonterminal.

  Args:
    production: An nltk.grammar.Production.

  Returns:
    True if the production is a PassThroughRule.
  """
  rhs = production.rhs()
  if len(rhs) != 1:
    return False
  return isinstance(rhs[0], nltk.grammar.Nonterminal)


def is_function_rule(production) -> bool:
  """Checks if the production is a FunctionRule's rule string.

  A production qualifies when its RHS has more than one item and exactly one
  of those items is not an nltk Nonterminal (i.e. is a terminal).

  Args:
    production: An nltk.grammar.Production.

  Returns:
    True if the production is a FunctionRule.
  """
  rhs = production.rhs()
  # A RHS of length 0 or 1 can never be a FunctionRule; skip the scan.
  if len(rhs) <= 1:
    return False
  # Count terminals (items that are NOT Nonterminals); the previous local name
  # `rhs_nonterminals` described the opposite of what was collected.
  num_terminals = sum(
      1 for term in rhs if not isinstance(term, nltk.grammar.Nonterminal))
  return num_terminals == 1


def is_concat_rule(production) -> bool:
  """Tells whether the production is a ConcatRule's rule string.

  A production qualifies when its RHS consists of exactly two items and both
  of them are nltk Nonterminals.

  Args:
    production: An nltk.grammar.Production.

  Returns:
    True if the production is a ConcatRule.
  """
  rhs = production.rhs()
  if len(rhs) != 2:
    return False
  left, right = rhs
  return (isinstance(left, nltk.grammar.Nonterminal) and
          isinstance(right, nltk.grammar.Nonterminal))


def is_output_token_a_variable(token: str) -> bool:
  """Tells whether an output token names a variable (starts with '?')."""
  variable_marker = '?'
  return token.startswith(variable_marker)


def add_variable_prefix(token: str) -> str:
  """Returns the token prepended with the '?' prefix that marks a variable."""
  return '?{}'.format(token)


def strip_variable_prefix(token: str) -> str:
  """Returns a shortened variable name, minus the variable-indicating prefix.

  Only a single leading '?' is removed; any '?' elsewhere in the token is kept,
  so that this is the inverse of prefixing a token with '?'. Tokens without
  the prefix are returned unchanged.

  Args:
    token: The (possibly '?'-prefixed) token.

  Returns:
    `token` without its leading '?', if it had one.
  """
  # Previously `token.replace('?', '')` stripped every '?' in the token, not
  # just the prefix.
  if token.startswith('?'):
    return token[1:]
  return token


@functools.lru_cache(maxsize=None)
def output_pattern_from_production(production) -> str:
  """Returns the normalized pattern of the production's LHS `sem` feature.

  Each LHS token is replaced by a placeholder numbered in order of first
  appearance. Variables (tokens with the '?' prefix) become x1, x2, ...; all
  other tokens become o1, o2, .... A repeated token reuses its placeholder.
  For example, the pattern of "[x3] LTURN [x3] WALK WALK LTURN [x1]" is:
  "x1 o1 x1 o2 o2 o1 x2".

  Args:
    production: The production to compute output pattern for.

  Returns:
    The space-separated placeholders; an empty string when the production has
    no LHS tokens.
  """
  variable_placeholders = {}
  terminal_placeholders = {}
  placeholders = []
  for token in extract_lhs_tokens(production):
    if is_output_token_a_variable(token):
      table, letter = variable_placeholders, 'x'
    else:
      table, letter = terminal_placeholders, 'o'
    # First sighting of a token assigns the next index for its kind.
    if token not in table:
      table[token] = f'{letter}{len(table) + 1}'
    placeholders.append(table[token])
  return ' '.join(placeholders)


def grammar_to_string(grammar) -> str:
  """Serializes a grammar into text parsable by FeatureGrammar.fromstring.

  The text begins with a `% start` directive for the grammar's start symbol,
  followed by one production per line (each rendered with `str`).

  The start symbol is written as `TYPE[sem=SEM]` when it carries a `sem`
  feature and as a bare `TYPE` otherwise. Other features of the start symbol
  are not written.

  Args:
    grammar: The grammar to serialize; presumably an nltk FeatureGrammar.

  Returns:
    The serialized grammar text.
  """
  start = grammar.start()
  start_string = f'{start[nltk.grammar.TYPE]}'
  if 'sem' in start:
    start_string = f"{start_string}[sem={str(start['sem'])}]"
  productions_string = '\n'.join(str(p) for p in grammar.productions())
  return f'% start {start_string}\n' + productions_string
