# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of conversion from source rules to rule strings.

The source rule is a RuleSchema instance or equivalent FeatureGrammar production
such as "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]", which can be
converted into different rule string formats.  This is to be used in sampling.py
and dataset_generation.py to generate datasets whose explicitly provided rules
are represented in different formats.

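Supported rule formats:
- FeatureGrammar production: such as "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and'
  C[sem=?x2]".
- Interpretation rule: such as "[x1 and x2] = [x1] [x2]".
- Natural language: such as 'The interpretation of a phrase x1 followed by
  "and" followed by a phrase x2 is the interpretation of x1 followed by the
  interpretation of x2.'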
"""

import re
from typing import Dict, List, Optional, Sequence

import nltk

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import nltk_utils


def _bracket(content):
  return f'[{content}]'


def interpretation_rule_input_tokens_from_production(
    production):
  """Returns the input tokens corresponding to a FeatureGrammar production.

  These are the tokens of the input side of the interpretation rule format,
  before they are joined and enclosed in brackets. Variables on the
  production's right-hand side have their variable prefix stripped, and all
  other tokens are kept as they are.

  E.g. From "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]" we get
  ['x1', 'and', 'x2'].

  Args:
    production: A FeatureGrammar production with `sem` features.

  Returns:
    The list of input tokens, in right-hand side order.
  """
  return [
      nltk_utils.strip_variable_prefix(token)
      if nltk_utils.is_output_token_a_variable(token) else token
      for token in nltk_utils.extract_rhs_tokens(production)
  ]


def interpretation_rule_input_string_from_production(
    production):
  """Returns the input string corresponding to a FeatureGrammar production.

  This is the input side of the interpretation rule format, without the
  enclosing brackets: the input tokens joined by single spaces.

  E.g. From "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]" we get
  "x1 and x2".

  Args:
    production: A FeatureGrammar production with `sem` features.

  Returns:
    The space-joined input string.
  """
  input_tokens = interpretation_rule_input_tokens_from_production(production)
  return ' '.join(input_tokens)


def _interpretation_rule_string_from_production(
    production):
  """Converts a FeatureGrammar production to its interpretation rule string.

  The production's right-hand side becomes the bracketed input on the left of
  the rule, and its `sem` value becomes the output on the right. For example,
  "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]" becomes
  "[x1 and x2] = [x1] [x2]". An empty output is written as EMPTY_STRING.

  Args:
    production: A FeatureGrammar production with `sem` features.

  Returns:
    The interpretation string.
  """
  output_parts = []
  for token in nltk_utils.extract_lhs_tokens(production):
    if not token:
      # Empty tokens contribute nothing to the output side.
      continue
    if nltk_utils.is_output_token_a_variable(token):
      token = _bracket(nltk_utils.strip_variable_prefix(token))
    output_parts.append(token)
  output_string = ' '.join(output_parts) if output_parts else 'EMPTY_STRING'

  input_string = _bracket(
      interpretation_rule_input_string_from_production(production))

  # The source production's right-hand side is the rule's left-hand side, and
  # vice versa.
  rule_separator = separators[enums.RuleFormat.INTERPRETATION_RULE]
  return input_string + rule_separator + output_string


def _interpretation_rule_mapping_from_grammar(
    grammar
):
  """Groups the grammar's productions by their interpretation rule string.

  A production such as "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' C[sem=?x2]" is
  keyed by "[x1 and x2] = [x1] [x2]". PassThroughRules such as
  "A[sem=?x1] -> B[sem=?x1]" are skipped, because their identity rule
  [x1] = [x1] should not be used as an explicit rule.

  Args:
    grammar: An nltk.grammar.FeatureGrammar.

  Returns:
    A dict from each interpretation rule string to the list of productions
    that produce it, in grammar order.
  """
  productions_by_rule = {}
  for production in grammar.productions():
    if nltk_utils.is_pass_through_rule(production):
      continue
    rule = _interpretation_rule_string_from_production(production)
    if rule not in productions_by_rule:
      productions_by_rule[rule] = []
    productions_by_rule[rule].append(production)
  return productions_by_rule


def _feature_grammar_production_mapping_from_grammar(
    grammar
):
  """Groups the grammar's productions by their `str()` representation.

  Args:
    grammar: An nltk.grammar.FeatureGrammar.

  Returns:
    A dict from each FeatureGrammar production string to the list of
    productions with that string, in grammar order.
  """
  productions_by_string = {}
  for production in grammar.productions():
    key = str(production)
    if key not in productions_by_string:
      productions_by_string[key] = []
    productions_by_string[key].append(production)
  return productions_by_string


def _merge_consecutive_terminals(tokens):
  """Returns the sequence of tokens with consecutive terminals merged.

  Adjacent non-variable tokens are joined with single spaces; variable tokens
  are passed through unchanged and separate one merged run from the next.
  Empty strings are ignored.

  Examples:
    ['a', 'b'] -> ['a b']
    ['a', 'b', '?x1'] -> ['a b', '?x1']
    ['', 'b'] -> ['b']

  Args:
    tokens: The sequence of tokens to process.

  Returns:
    A new list of tokens.
  """
  merged = []
  pending_terminals = []
  for token in tokens:
    if not nltk_utils.is_output_token_a_variable(token):
      if token:
        pending_terminals.append(token)
      continue
    # A variable closes the current run of terminals, if any.
    if pending_terminals:
      merged.append(' '.join(pending_terminals))
      pending_terminals = []
    merged.append(token)

  if pending_terminals:
    merged.append(' '.join(pending_terminals))
  return merged


def _process_repeated_variables(tokens):
  """Returns the sequence of tokens with repeated variables replaced.

  Each run of identical, adjacent variable tokens collapses into a single
  token annotated with the run length. Non-variable tokens are kept as they
  are and break runs. Empty strings are dropped from the result.

  Examples:
    ['?x1', '?x1', 'b'] -> ['?x1 repeated twice', 'b']
    ['?x1', '?x1', '?x1', 'b'] -> ['?x1 repeated thrice', 'b']

  Args:
    tokens: The sequence of tokens to process.

  Returns:
    A new list of tokens.
  """

  def _repeat_phrase(count):
    if count == 2:
      return 'repeated twice'
    if count == 3:
      return 'repeated thrice'
    return f'repeated {count} times'

  # Each entry is [token, count]. count is 0 for a non-variable token, and
  # the number of adjacent repetitions for a run of a variable token.
  runs = []
  for token in tokens:
    if not nltk_utils.is_output_token_a_variable(token):
      runs.append([token, 0])
    elif runs and runs[-1][1] > 0 and runs[-1][0] == token:
      runs[-1][1] += 1
    else:
      runs.append([token, 1])

  rendered = []
  for token, count in runs:
    if count == 0:
      rendered.append(token)
    elif token:
      # Runs of an empty variable token are dropped entirely.
      if count == 1:
        rendered.append(token)
      else:
        rendered.append(f'{token} {_repeat_phrase(count)}')

  return [token for token in rendered if token]


def _a_phrase(name):
  return f'a phrase {name}'


def _the_interpretation_of(name):
  return f'the interpretation of {name}'


def _quote(content):
  return f'"{content}"'


def _sentence_from_string(string, end = '.'):
  # Since the string often has double quotes in it, which interfere with
  # string.capitalize(), we split and rejoin to make sure only the first word
  # is capitalized.
  words = string.split(' ')
  words[0] = words[0].capitalize()
  sentence = ' '.join(words)
  return f'{sentence}{end}'


def _natural_language_string_from_production(
    production):
  """Converts a FeatureGrammar production to natural language representation.

  E.g. From "X[sem=(?x1+?x2)] -> Y[sem=?x1] 'and' Z[sem=?x2]" we get
  'The interpretation of a phrase x1 followed by "and" followed by a phrase x2
  is the interpretation of x1 followed by the interpretation of x2.'

  Consecutive terminals are merged into one quoted phrase. Repeated output
  variables are described with a repetition phrase (e.g. "repeated twice"),
  and an empty output is described as "the empty string".

  Args:
    production: A FeatureGrammar production with `sem` features.

  Returns:
    The natural language string.
  """
  joiner = ' followed by '

  output_tokens = nltk_utils.extract_lhs_tokens(production)
  input_tokens = nltk_utils.extract_rhs_tokens(production)
  output_tokens = _merge_consecutive_terminals(output_tokens)
  input_tokens = _merge_consecutive_terminals(input_tokens)
  output_tokens = _process_repeated_variables(output_tokens)

  def _describe_output_token(token):
    if nltk_utils.is_output_token_a_variable(token):
      return _the_interpretation_of(nltk_utils.strip_variable_prefix(token))
    return _quote(token)

  def _describe_input_token(token):
    if nltk_utils.is_output_token_a_variable(token):
      return _a_phrase(nltk_utils.strip_variable_prefix(token))
    return _quote(token)

  output_parts = [
      _describe_output_token(token) for token in output_tokens if token
  ] or ['the empty string']
  output_description = joiner.join(output_parts)

  input_parts = [_describe_input_token(token) for token in input_tokens]
  input_description = _the_interpretation_of(joiner.join(input_parts))

  # The source production's right-hand side is described first, then its
  # `sem` value.
  rule_separator = separators[enums.RuleFormat.NATURAL_LANGUAGE]
  return _sentence_from_string(
      f'{input_description}{rule_separator}{output_description}')


def natural_language_non_rule_request(request):
  """Phrases `request` as a natural language question about its meaning.

  E.g. 'jump twice' -> 'What is the interpretation of "jump twice"?'

  Args:
    request: The input phrase being asked about.

  Returns:
    The question string.
  """
  interpretation = _the_interpretation_of(_quote(request))
  return _sentence_from_string(f'what is {interpretation}', end='?')


def _natural_language_mapping_from_grammar(
    grammar
):
  """Groups the grammar's productions by their natural language string.

  PassThroughRules are skipped, since the resulting rule "the interpretation
  of a phrase x1 is the interpretation of x1" is not useful.

  Args:
    grammar: An nltk.grammar.FeatureGrammar.

  Returns:
    A dict from each natural language representation string to the list of
    productions that produce it, in grammar order.
  """
  productions_by_rule = {}
  for production in grammar.productions():
    if nltk_utils.is_pass_through_rule(production):
      continue
    rule = _natural_language_string_from_production(production)
    if rule not in productions_by_rule:
      productions_by_rule[rule] = []
    productions_by_rule[rule].append(production)
  return productions_by_rule


def rule_mapping_from_grammar(
    grammar,
    rule_format):
  """Returns the mapping of rule strings to source productions.

  Every rule string is mapped to the list of productions it was derived from.

  Args:
    grammar: An nltk.grammar.FeatureGrammar.
    rule_format: The format of the output rule strings.

  Returns:
    The dictionary mapping converted rule strings to source productions.

  Raises:
    ValueError: If `rule_format` is not one of the supported formats.
  """
  formats = enums.RuleFormat
  if rule_format == formats.FEATURE_GRAMMAR_PRODUCTION:
    return _feature_grammar_production_mapping_from_grammar(grammar)
  if rule_format == formats.INTERPRETATION_RULE:
    return _interpretation_rule_mapping_from_grammar(grammar)
  if rule_format == formats.NATURAL_LANGUAGE:
    return _natural_language_mapping_from_grammar(grammar)
  raise ValueError(f'Conversion to rule format {rule_format} not supported.')


def rule_from_production(production,
                         rule_format):
  """Converts the production to a rule string of the given format.

  Args:
    production: The nltk.grammar.Production to be converted.
    rule_format: The format of the output rule strings.

  Returns:
    The rule string of the production, or None if the production has no rule
    string in this format (a PassThroughRule in the interpretation rule or
    natural language format).

  Raises:
    ValueError: If `rule_format` is not one of the supported formats.
  """
  formats = enums.RuleFormat
  if rule_format == formats.FEATURE_GRAMMAR_PRODUCTION:
    return str(production)

  if rule_format == formats.INTERPRETATION_RULE:
    convert = _interpretation_rule_string_from_production
  elif rule_format == formats.NATURAL_LANGUAGE:
    convert = _natural_language_string_from_production
  else:
    raise ValueError(f'Conversion to rule format {rule_format} not supported.')

  # In both remaining formats, PassThroughRules have no rule string.
  if nltk_utils.is_pass_through_rule(production):
    return None
  return convert(production)


def _input_phrase_from_rule(rule):
  """Returns the input phrase of the interpretation rule string."""
  pattern = r'\[([^\[\]]+)\] = .+'
  match = re.fullmatch(pattern, rule)
  if match is None:
    raise ValueError(f'Incorrect rule format: {rule}')
  return match.group(1)


def rule_pattern_from_rule(rule):
  """Returns the rule pattern of the interpretation rule string.

  Examples:
    "[x1 twice] = [x1] [x1]" -> "[x1 _] = [x1] [x1]"
    "[x1 left] = [x1] [x1]" -> "[x1 _] = [x1] [x1]"
    "[x1 and x2] = [x1] [x2]" -> "[x1 _ x2] = [x1] [x2]"

  The same placeholder token "_" is used for all non-variable tokens.

  The current implementation assumes the rule has been normalized, in that all
  the variable tokens match the pattern /x[0-9]+/, and all the non-variable
  tokens do not match it.

  Args:
    rule: The interpretation rule string to be converted into rule pattern.
  """
  input_phrase = _input_phrase_from_rule(rule)
  input_tokens = input_phrase.split()

  variable_pattern = r'x\d+'
  updated_tokens = []
  for token in input_tokens:
    if re.fullmatch(variable_pattern, token):
      updated_tokens.append(token)
    else:
      updated_tokens.append('_')
  updated_phrase = ' '.join(updated_tokens)
  return rule.replace(input_phrase, updated_phrase)


# Separator placed between the two sides of a rule string, per rule format.
# This rule format-specific information currently needs to be exposed to
# data_generation.py.
# NOTE(review): the module docstring refers to dataset_generation.py; confirm
# which module consumes this.
separators: Dict[enums.RuleFormat, str] = {
    enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION: ' -> ',  # "A -> B"
    enums.RuleFormat.INTERPRETATION_RULE: ' = ',  # "[x1] = [x1] [x1]"
    enums.RuleFormat.NATURAL_LANGUAGE: ' is ',  # "... is ..."
}
