# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for converting a FeatureGrammar to an equivalent GrammarSchema.

The main use case of this is to convert the standard grammars (loaded as
nltk.grammar.FeatureGrammar by the grammar_loader module) to the GrammarSchema
format so they can be used as templates in the grammar_generation module.
"""

from typing import Dict, List
import nltk

from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import nltk_utils


def _category_from_nonterminal(nonterminal):
  """Returns the category symbol stored in a nonterminal.

  Args:
    nonterminal: A nonterminal that can be indexed by nltk.grammar.TYPE.

  Returns:
    The value held under the nonterminal's nltk.grammar.TYPE feature.
  """
  category = nonterminal[nltk.grammar.TYPE]
  return category


def _function_arg_from_nonterminal(nonterminal):
  """Builds the FunctionArg that describes a nonterminal.

  Args:
    nonterminal: A nonterminal whose `sem` feature exposes a `name` attribute
      (presumably an nltk.Variable -- confirm against callers).

  Returns:
    A FunctionArg whose variable is the name of the `sem` feature and whose
    category is the nonterminal's category symbol.
  """
  variable_name = nonterminal['sem'].name
  category = _category_from_nonterminal(nonterminal)
  return gs.FunctionArg(variable=variable_name, category=category)


def _output_sequence_from_nonterminal(nonterminal):
  """Reads the output sequence encoded in a nonterminal's `sem` feature.

  Three shapes of `sem` are handled: a plain string, a single nltk.Variable,
  and an iterable mixing strings and nltk.Variables (such as a concatenation).
  Every nltk.Variable is represented in the result by its name.

  Args:
    nonterminal: An nltk.grammar.FeatureStructNonterminal with `sem` feature.

  Returns:
    The output sequence as a list of strings.
  """
  sem = nonterminal['sem']

  # Scalar cases: a single token or a single variable.
  if isinstance(sem, str):
    return [sem]
  if isinstance(sem, nltk.Variable):
    return [sem.name]

  # Otherwise walk the items, replacing each variable with its name.
  return [
      item.name if isinstance(item, nltk.Variable) else item for item in sem
  ]


def _primitive_mapping_from_production(production):
  """Converts a production to a PrimitiveMapping.

  Args:
    production: A production providing `lhs()` and `rhs()`.

  Returns:
    A PrimitiveMapping whose input sequence is the RHS as a list, and whose
    output sequence and category are both read from the LHS nonterminal.
  """
  lhs = production.lhs()
  return gs.PrimitiveMapping(
      input_sequence=list(production.rhs()),
      output_sequence=_output_sequence_from_nonterminal(lhs),
      category=_category_from_nonterminal(lhs))


def _pass_through_rule_from_production(production):
  """Converts a production to a PassThroughRule.

  Args:
    production: A production providing `lhs()` and `rhs()`; only the first
      RHS item is used.

  Returns:
    A PassThroughRule whose category comes from the LHS and whose argument is
    built from the first item of the RHS.
  """
  category = _category_from_nonterminal(production.lhs())
  first_rhs_item = production.rhs()[0]
  return gs.PassThroughRule(
      category=category,
      arg=_function_arg_from_nonterminal(first_rhs_item))


def _function_rule_from_production(
    production):
  """Converts the production to a FunctionRule.

  The RHS of the production must contain exactly one terminal, which serves as
  the function phrase.  Every nonterminal on the RHS becomes an argument of the
  function, and the nonterminals that appear after the terminal are counted as
  postfix arguments.

  Args:
    production: An nltk.grammar.Production.

  Returns:
    A FunctionRule.

  Raises:
    ValueError: If the RHS does not contain exactly one terminal.
  """
  args = []
  terminals = []
  num_postfix_args = 0
  for item in production.rhs():
    if isinstance(item, nltk.grammar.Nonterminal):
      args.append(_function_arg_from_nonterminal(item))
      # Arguments seen after the function phrase are postfix arguments.
      if terminals:
        num_postfix_args += 1
    else:
      terminals.append(item)

  # Fail loudly rather than build a rule with a missing (None) function phrase
  # or silently keep only the last of several terminals.
  if len(terminals) != 1:
    raise ValueError(
        f'Expected exactly one terminal in the RHS of production: '
        f'{production}, but found {len(terminals)}.')

  lhs = production.lhs()
  function_rule = gs.FunctionRule(
      function_phrase=terminals,
      category=_category_from_nonterminal(lhs),
      num_args=len(args),
      num_postfix_args=num_postfix_args,
      args=args,
      output_sequence=_output_sequence_from_nonterminal(lhs))
  return function_rule


def _concat_rule_from_production(production):
  """Converts a production to a ConcatRule.

  Args:
    production: A production providing `lhs()` and `rhs()`; the first two RHS
      items are used as the rule's arguments.

  Returns:
    A ConcatRule whose category and output sequence come from the LHS, and
    whose arg1 and arg2 are built from the first and second RHS items.
  """
  lhs = production.lhs()
  rhs = production.rhs()
  return gs.ConcatRule(
      category=_category_from_nonterminal(lhs),
      arg1=_function_arg_from_nonterminal(rhs[0]),
      arg2=_function_arg_from_nonterminal(rhs[1]),
      output_sequence=_output_sequence_from_nonterminal(lhs))


def _level_by_category_from_feature_grammar(feature_grammar):
  """Assigns a level to each non-primitive category of the grammar.

  Productions are scanned from last to first, and categories are numbered
  starting at 1 in the order in which they are first met.  This relies on the
  productions of feature_grammar already being ordered from the highest level
  to the lowest.  Productions that correspond to PrimitiveMappings are skipped.

  Args:
    feature_grammar: An nltk.grammar.FeatureGrammar.

  Returns:
    A dict mapping grammar category symbol to its level.
  """
  levels = {}
  for production in reversed(feature_grammar.productions()):
    if nltk_utils.is_primitive_mapping(production):
      continue
    category = _category_from_nonterminal(production.lhs())
    # The next free level is always one more than the number of categories
    # recorded so far; categories already seen keep their level.
    levels.setdefault(category, len(levels) + 1)
  return levels


def _validate_equal(feature_grammar, grammar_schema):
  """Checks that the schema reproduces the given feature grammar.

  Two things are compared: the set of production strings (str() of each
  grammar production versus to_rule_string() of each schema rule) and the
  start symbols.

  Args:
    feature_grammar: An nltk.grammar FeatureGrammar.
    grammar_schema: A GrammarSchema.

  Raises:
    ValueError: If the production strings or the start symbols differ.
  """
  production_strings_from_grammar = {
      str(production) for production in feature_grammar.productions()
  }
  production_strings_from_schema = {
      rule.to_rule_string() for rule in grammar_schema.get_all_rules()
  }
  if production_strings_from_grammar != production_strings_from_schema:
    raise ValueError(
        f'Grammars do not have the same production strings.'
        f'\nFeatureGrammar production strings: '
        f'{production_strings_from_grammar}'
        f'\nGrammarSchema production strings: {production_strings_from_schema}')

  start_symbol_of_grammar = feature_grammar.start()[nltk.grammar.TYPE]
  start_symbol_of_schema = grammar_schema.get_start_symbol()
  if start_symbol_of_grammar == start_symbol_of_schema:
    return
  raise ValueError(f'Grammars do not have the same start symbol.'
                   f'\nFeatureGrammar start symbol: {start_symbol_of_grammar}'
                   f'\nGrammarSchema start symbol: {start_symbol_of_schema}')


def grammar_schema_from_feature_grammar(feature_grammar):
  """Builds a GrammarSchema that mirrors the given FeatureGrammar.

  Each production is classified with the nltk_utils.is_* predicates and
  converted to the matching schema rule.  Function rules are grouped by the
  level of their LHS category; a pass-through rule is stored per level; the
  concat rule and its level are stored singly, so a later one replaces an
  earlier one.  The assembled schema is validated against feature_grammar
  before it is returned.

  Args:
    feature_grammar: An nltk.grammar.FeatureGrammar.

  Returns:
    A GrammarSchema equivalent to the feature_grammar.

  Raises:
    ValueError: If a production is not recognized by any of the
      nltk_utils.is_* predicates used here, or if the resulting schema does
      not validate as equal to feature_grammar.
  """
  level_by_category = _level_by_category_from_feature_grammar(feature_grammar)

  primitives = []
  functions_by_level = {}
  pass_through_rules = {}
  concat_rule = None
  concat_rule_level = None

  for production in feature_grammar.productions():
    # Primitive mappings carry no level, so handle them before the lookup.
    if nltk_utils.is_primitive_mapping(production):
      primitives.append(_primitive_mapping_from_production(production))
      continue

    category = _category_from_nonterminal(production.lhs())
    level = level_by_category[category]

    if nltk_utils.is_pass_through_rule(production):
      pass_through_rules[level] = _pass_through_rule_from_production(production)
      continue

    if nltk_utils.is_function_rule(production):
      if level not in functions_by_level:
        functions_by_level[level] = []
      functions_by_level[level].append(
          _function_rule_from_production(production))
      continue

    if nltk_utils.is_concat_rule(production):
      concat_rule_level = level
      concat_rule = _concat_rule_from_production(production)
      continue

    raise ValueError(f'Failed to convert production: {production}.')

  grammar_schema = gs.GrammarSchema(
      primitives, functions_by_level, pass_through_rules, concat_rule_level,
      concat_rule)
  _validate_equal(feature_grammar, grammar_schema)
  return grammar_schema
