# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for representing SCAN-like grammars in an easily manipulable way."""

import abc
import collections
import dataclasses
import itertools
import logging
from typing import (AbstractSet, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence)

import immutabledict
import numpy as np

from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import stats_utils


@dataclasses.dataclass(order=True)
class FunctionArg:
  """One argument slot of a variable-based rule (see FunctionRule below).

  For a righthand-side term such as 'V[sem=?x1]', the argument's variable is
  '?x1' and its category is 'V'. Both fields default to None so that a
  partially-specified argument can be represented.

  Attributes:
    variable: Semantic variable bound by the argument, e.g. '?x1'.
    category: Syntactic category of the argument, e.g. 'V'.
  """
  variable: Optional[str] = None
  category: Optional[str] = None

  def to_string(self):
    """Formats the arg as a term for the rhs of a FeatureGrammar rule."""
    return '{}[sem={}]'.format(self.category, self.variable)

  def validate(self):
    """Raises ValueError unless both variable and category are non-empty.

    The variable is checked before the category.
    """
    required_fields = (
        (self.variable, 'Arg variable not specified'),
        (self.category, 'Arg category not specified'),
    )
    for value, error_message in required_fields:
      if not value:
        raise ValueError(error_message)


@dataclasses.dataclass
class RuleSchema(metaclass=abc.ABCMeta):
  """Interface and shared functionality for rules in a GrammarSchema.

  Example:
    Rule: V[sem=(...)] -> ...
      For a rule of the above form, the category would be 'V'.
    For examples of args, output_sequence, etc., see the various sub-classes
    of RuleSchema (FunctionRule, ConcatRule, etc.) below.

  Attributes:
    category: Syntactic category output by this rule.
  """

  category: Optional[str] = None

  def get_args(self):
    """Returns a sequence of args (i.e., RHS terms containing variables).

    Should be overridden for rules that actually take args.
    """
    return []

  def get_arg_categories(self):
    """Returns the set of non-empty categories among this rule's args."""
    return {arg.category for arg in self.get_args() if arg.category}

  @abc.abstractmethod
  def get_output_sequence(self):
    """Returns the output tokens (each a variable or string literal)."""

  @abc.abstractmethod
  def get_rhs_terms(self):
    """Returns a sequence of terms to appear in the rule's righthand side."""

  def _get_output_sequence_string(self):
    """Returns the part of the rule string representing the output sequence.

    An empty sequence is rendered as a quoted empty string. A single token is
    rendered as-is if it is a variable and quoted otherwise. Multiple tokens
    are joined with '+' inside parentheses.
    """
    output_terms = self.get_output_sequence()
    if not output_terms:
      return repr('')
    elif len(output_terms) == 1:
      return output_terms[0] if nltk_utils.is_output_token_a_variable(
          output_terms[0]) else repr(output_terms[0])
    else:
      return f'({"+".join(output_terms)})'

  def to_rule_string(self):
    """Returns the corresponding rule in FeatureGrammar string format."""
    rhs_string = ' '.join(self.get_rhs_terms())
    return (f'{self.category}[sem={self._get_output_sequence_string()}] -> '
            f'{rhs_string}')

  def validate(self, options):
    """Checks that the rule is fully-specified and valid.

    Should normally be overridden by the sub-class (calling this method via
    super()).

    Args:
      options: Options with respect to which to validate. Its
        level_by_category mapping is consulted for the levels of the rule's
        category and of each arg's category.

    Raises:
      ValueError: If any issue was found, including a rule or arg category
        that has no level in options.level_by_category.
    """
    # Validate category
    if not self.category:
      raise ValueError(f'Rule lacking category: {self}')
    # Validate variables
    for arg in self.get_args():
      arg.validate()
      arg_level = options.level_by_category.get(arg.category)
      rule_level = options.level_by_category.get(self.category)
      # A category unknown to the options has no level (None). Comparing None
      # with an int would raise TypeError rather than the documented
      # ValueError, which callers that only catch ValueError would not handle.
      if arg_level is None or rule_level is None:
        raise ValueError(
            f'Unknown level for rule category {self.category} or arg '
            f'category {arg.category}')
      if arg_level >= rule_level:
        raise ValueError('Arg level must be smaller than function level')
    vars_in_output = {
        token for token in self.get_output_sequence()
        if nltk_utils.is_output_token_a_variable(token)
    }
    vars_in_rhs = {arg.variable for arg in self.get_args()}
    # Fewer distinct variables than args means some variable is repeated.
    if len(vars_in_rhs) < len(self.get_args()):
      raise ValueError('Duplicate variable in function args')
    if vars_in_output.symmetric_difference(vars_in_rhs):
      raise ValueError('Vars in output do not match those in rhs')


@dataclasses.dataclass
class RuleWithExplicitOutputSequence(RuleSchema):
  """Base class for rules whose output sequence is stored directly as a field.

  Attributes:
    output_sequence: Output tokens, each either a variable (e.g. '?x1') or a
      string literal (e.g. 'WALK').
  """
  output_sequence: List[str] = dataclasses.field(default_factory=list)

  def get_output_sequence(self):
    """See parent class. Returns the stored output_sequence list itself."""
    stored_tokens = self.output_sequence
    return stored_tokens


@dataclasses.dataclass
class PrimitiveMapping(RuleWithExplicitOutputSequence):
  """Schema of a variable-free rule mapping an input phrase to an output phrase.

  Both input_sequence and output_sequence are expected to hold string literals
  only.

  Example:
    Rule: U[sem=('WALK')] -> 'walk'
    The input tokens are ['walk'], the output tokens are ['WALK'], and the
    category is 'U'.

  Attributes:
    input_sequence: Input tokens that trigger this rule, e.g. ['walk'] in the
      above example.
  """
  input_sequence: List[str] = dataclasses.field(default_factory=list)

  def get_rhs_terms(self):
    """See parent class. Each input token is quoted via repr()."""
    return list(map(repr, self.input_sequence))

  def validate(self, options):
    """See parent class. Additionally requires a non-empty input sequence."""
    super().validate(options)
    if self.input_sequence:
      return
    raise ValueError('Input sequence empty')


@dataclasses.dataclass
class FunctionRule(RuleWithExplicitOutputSequence):
  """Schema of a rule implementing a function phrase such as 'twice'.

  Example:
    Rule: S[sem=(?x1 + ?x1)] -> V[sem=?x1] 'twice'
    The function phrase is ['twice'] and the function category is 'S'. There is
    a single argument, which precedes the phrase (one prefix arg, zero postfix
    args). Its category is 'V' and its variable is '?x1'. The output sequence
    is ['?x1', '?x1'].

  num_args is stored explicitly rather than derived from args. This lets a
  partially-specified FunctionRule (for example, inside a GrammarSchema
  template) state how many arguments it should have before the arguments
  themselves exist. Once the rule is fully specified, num_args must equal
  len(args).

  Attributes:
    function_phrase: Input tokens naming the function, e.g. ['twice'].
    num_args: Intended number of arguments (1 in the example). Must equal
      len(args) once args is populated.
    num_postfix_args: Number of arguments following the function phrase (0 in
      the example).
    args: The arguments themselves, e.g.
      [FunctionArg(category='V', variable='?x1')].

  The number of arguments preceding the function phrase is exposed as the
  derived property num_prefix_args (1 in the example).
  """
  function_phrase: List[str] = dataclasses.field(default_factory=list)
  num_args: Optional[int] = None
  num_postfix_args: Optional[int] = None
  args: List[FunctionArg] = dataclasses.field(default_factory=list)

  @property
  def num_prefix_args(self):
    """Number of args before the function phrase, or None if undetermined.

    The result is clamped at zero, so an inconsistent
    num_postfix_args > num_args does not produce a negative count.
    """
    if self.num_args is not None and self.num_postfix_args is not None:
      return max(0, self.num_args - self.num_postfix_args)
    return None

  def get_function_phrase_string(self):
    """Joins the function phrase tokens with single spaces.

    E.g., ['two', 'times'] becomes 'two times'.
    """
    separator = ' '
    return separator.join(self.function_phrase)

  def get_args(self):
    """See parent class."""
    return self.args

  def get_rhs_terms(self):
    """See parent class.

    Emits the prefix args, then the quoted function phrase tokens, then the
    remaining (postfix) args.
    """
    # The prefix count is capped at len(args) so that an incompletely or
    # inconsistently populated rule still renders (useful in tests and
    # debugging) instead of failing on slicing. Reporting such problems is
    # validate()'s job.
    split_index = min(self.num_prefix_args or 0, len(self.args))
    prefix_terms = [arg.to_string() for arg in self.args[:split_index]]
    phrase_terms = [repr(token) for token in self.function_phrase or ()]
    postfix_terms = [arg.to_string() for arg in self.args[split_index:]]
    return prefix_terms + phrase_terms + postfix_terms

  def validate(self, options):
    """See parent class. Also checks args, phrase and argument counts."""
    super().validate(options)
    # Checks are evaluated lazily and in order, so later checks may assume
    # that earlier ones passed (e.g. the counts are not None).
    ordered_checks = (
        (lambda: not self.args, 'Arg list empty'),
        (lambda: not self.function_phrase, 'Empty function phrase'),
        (lambda: self.num_args is None, 'Num_args not specified'),
        (lambda: self.num_postfix_args is None,
         'Num_postfix_args not specified'),
        (lambda: len(self.args) != self.num_args,
         'Mismatch in number of args'),
        (lambda: self.num_postfix_args < 0,
         'Num_postfix_args cannot be negative'),
        (lambda: self.num_postfix_args > self.num_args,
         'Num_postfix_args cannot exceed num_args'),
    )
    for check_failed, error_message in ordered_checks:
      if check_failed():
        raise ValueError(error_message)


@dataclasses.dataclass
class PassThroughRule(RuleSchema):
  """Schema of a rule that simply re-labels one category as another.

  The righthand category must belong to the precedence level directly below
  that of the lefthand category.

  Example:
    Rule: D[sem=?x1] -> U[sem=?x1]

  Attribute:
    arg: The single righthand argument. In the above example this would be
      FunctionArg(category='U', variable='?x1').
  """
  arg: FunctionArg = dataclasses.field(default_factory=FunctionArg)

  def get_args(self):
    """See parent class."""
    return [self.arg]

  def get_output_sequence(self):
    """See parent class. The output is just the arg's variable (if set)."""
    variable = self.arg.variable
    if not variable:
      return []
    return [variable]

  def get_rhs_terms(self):
    """See parent class."""
    return [self.arg.to_string()]

  def validate(self, options):
    """See parent class. Also requires each arg to be exactly one level down."""
    super().validate(options)
    levels = options.level_by_category
    for arg in self.get_args():
      if levels.get(arg.category) == levels.get(self.category) - 1:
        continue
      raise ValueError(
          f'PassThroughRule arg must be from exactly one level below: {arg}')


@dataclasses.dataclass
class ConcatRule(RuleWithExplicitOutputSequence):
  """Schema of a rule that concatenates two righthand terms.

  The categories of both righthand terms must belong to the precedence level
  directly below that of the lefthand category (enforced only when
  options.validate_concat_rule_level is set).

  Example:
    Rule: D[sem=(?x2+?x1)] -> U[sem=?x1] W[sem=?x2]

  Attributes:
    arg1: First righthand argument. In the above example this would be
      FunctionArg(category='U', variable='?x1').
    arg2: Second righthand argument. In the above example this would be
      FunctionArg(category='W', variable='?x2').
  """
  arg1: FunctionArg = dataclasses.field(default_factory=FunctionArg)
  arg2: FunctionArg = dataclasses.field(default_factory=FunctionArg)

  def get_args(self):
    """See parent class."""
    return [self.arg1, self.arg2]

  def get_rhs_terms(self):
    """See parent class."""
    return [arg.to_string() for arg in (self.arg1, self.arg2)]

  def validate(self, options):
    """See parent class. Optionally checks that args are one level down."""
    super().validate(options)

    # The arg-level check is opt-in via the corresponding options flag.
    if options.validate_concat_rule_level:
      levels = options.level_by_category
      for arg in self.get_args():
        if levels.get(arg.category) != levels.get(self.category) - 1:
          raise ValueError(
              f'ConcatRule arg must be from exactly one level below: {arg}')


@dataclasses.dataclass
class GrammarSchema:
  """Schema of a SCAN-like grammar in a format convenient for generation.

  If all fields are specified, then the schema uniquely describes a
  FeatureGrammar to be generated with the corresponding contents.

  A GrammarSchema with one or more fields left unspecified can also be provided
  to the GrammarGenerator as a partial specification of the grammar to be
  generated. For this purpose, a field is considered to be unspecified if its
  value is None or an empty List or an empty Dict. The GrammarGenerator would
  then automatically populate any such unspecified fields -- either in the
  GrammarSchema itself or in any of the objects contained therein -- with values
  generated randomly in accordance with the GrammarOptions. (See below for a
  few special cases where a different criterion is used to signal to the
  GrammarGenerator to auto-populate the values.)

  Note that to reduce the chance of ambiguity in the grammar, functions are
  organized into strict precedence levels (with lower level meaning tighter
  binding), and the following restrictions are applied to their interactions:
  - each function outputs a category specific to that precedence level
  - functions are only allowed to take as arguments categories associated with
    lower precedence levels.
  One effect of the above restrictions is that the generated grammars are always
  finite and are never recursive.

  Level 0 represents the rules that implement the primitive mappings.

  Each precedence level can optionally be linked to its previous level through
  a pass-through rule of the form
    'CurrLevelCategory -> PrevLevelCategory'
  Further, there can be at most one concatenation rule of the form
    'CurrLevelCategory -> PrevLevelCategory1 PrevLevelCategory2'
  Concatenation rules are limited to at most one globally, rather than one per
  precedence level, again to reduce the chance of ambiguous parses.

  Attributes:
    primitives: List of primitive mapping rules.
    functions_by_level: Mapping of precedence level to FunctionRules associated
      with that level.
    pass_through_rules: Mapping of precedence level to the PassThroughRule
      associated with that level (if any). In the context of a template
      GrammarSchema, pass_through_rules is considered to be eligible for
      auto-population if and only if functions_by_level is.
    concat_rule_level: The precedence level at which the ConcatRule is intended
      to exist (if any). If concat_rule is already fully specified, then
      concat_rule_level should match the precedence level of concat_rule's
      category. In the context of a template GrammarSchema, concat_rule_level
      is considered to be eligible for auto-population if and only if it is
      None and functions_by_level is also eligible for auto-population.
    concat_rule: The ConcatRule (if any). In the context of a template
      GrammarSchema, concat_rule is considered to be eligible for
      auto-population if and only if it is None and concat_rule_level is
      specified (or has been auto-generated).
  """
  primitives: List[PrimitiveMapping] = dataclasses.field(default_factory=list)
  functions_by_level: Dict[int, List[FunctionRule]] = dataclasses.field(
      default_factory=dict)
  pass_through_rules: Dict[int, PassThroughRule] = dataclasses.field(
      default_factory=dict)
  concat_rule_level: Optional[int] = None
  concat_rule: Optional[ConcatRule] = None

  def sample_rules(self, options,
                   rng):
    """Samples rules according to the provided sampling options.

    The rules that are not sampled are removed from this schema in place.
    PassThroughRules are never removed, and at least one rule is kept for each
    category output by a non-pass-through rule. Levels of functions_by_level
    left without rules are deleted. If the ConcatRule is dropped, then
    concat_rule_level is cleared as well.

    Args:
      options: The options that specify how the rules should be sampled
        (num_rules, num_rules_min, num_rules_max, num_rules_stddev). Nothing is
        done if options.num_rules is negative.
      rng: RandomState used for randomly selecting which rules should be
        sampled.

    Raises:
      ValueError: If num_rules_min is smaller than the number of categories, or
        num_rules_max exceeds the number of available rules.
    """
    if options.num_rules < 0:
      return

    # Collect all actual_rules (i.e., ignore PassThroughRules) and their
    # categories.
    actual_rules = []
    categories = set()
    for rule in self.get_all_rules():
      if not isinstance(rule, PassThroughRule):
        actual_rules.append(rule)
        categories.add(rule.category)
    rng.shuffle(actual_rules)

    # Make sure that the front of the actual_rules list contains one rule for
    # each category. Moving the rule at index i to the front only shifts the
    # already-visited rules at indices < i, so index-based iteration still
    # visits every rule exactly once.
    min_sample_count = len(categories)
    for i in range(len(actual_rules)):
      rule = actual_rules[i]
      if rule.category in categories:
        categories.remove(rule.category)
        actual_rules.insert(0, actual_rules.pop(i))

    # Make sure that the flags are compatible with the grammar.
    if options.num_rules_min < min_sample_count:
      raise ValueError(
          f'Minimum number of rules {options.num_rules_min} must be '
          f'large enough to sample a rule for each category {min_sample_count}')
    if options.num_rules_max > len(actual_rules):
      raise ValueError(
          f'Maximum number of rules {options.num_rules_max} must not be larger '
          f'than the number of available rules {len(actual_rules)}')

    # Sample rules and collect their ids.
    sample_count = round(
        stats_utils.sample_clipped_truncated_normal(
            left=options.num_rules_min,
            right=options.num_rules_max,
            mean=options.num_rules,
            std=options.num_rules_stddev,
            rng=rng))
    sampled_rules = actual_rules[0:sample_count]
    # Rules are matched by identity, since distinct rules may compare equal.
    sampled_rule_ids = set(map(id, sampled_rules))
    logging.info('Sampling %d rules: %s', sample_count, sampled_rules)

    # Remove all rules that are not sampled and clean up the rule map by
    # removing empty levels.
    def remove_not_sampled(rules):
      rules[:] = filter(lambda x: id(x) in sampled_rule_ids, rules)

    if (self.concat_rule is not None and
        id(self.concat_rule) not in sampled_rule_ids):
      self.concat_rule = None
      # Without this, validate() would reject the schema ('Concat rule level
      # specified but rule missing') and get_max_level() would still count the
      # stale level.
      self.concat_rule_level = None
    remove_not_sampled(self.primitives)
    levels_to_remove = set()
    for level, rules in self.functions_by_level.items():
      remove_not_sampled(rules)
      if not rules:
        levels_to_remove.add(level)
    for level in levels_to_remove:
      del self.functions_by_level[level]

  def get_max_level(self):
    """Returns the maximum precedence level for which any rule is defined.

    Considers the keys of functions_by_level and pass_through_rules and the
    concat_rule_level, defaulting to 0.
    """
    return max(
        max(self.functions_by_level, default=0),
        max(self.pass_through_rules, default=0), self.concat_rule_level or 0)

  def get_all_rules(self):
    """Returns an iterator over all rules in the grammar.

    Order: primitives, function rules (by functions_by_level iteration order),
    pass-through rules, then the concat rule (if any).
    """
    all_rules = itertools.chain(
        self.primitives,
        itertools.chain.from_iterable(self.functions_by_level.values()),
        self.pass_through_rules.values())
    if self.concat_rule is not None:
      all_rules = itertools.chain(all_rules, [self.concat_rule])
    return all_rules

  def _get_rules_by_level(self):
    """Returns a mapping of precedence level to the list of rules at it.

    Primitives are placed at level 0. The concat rule is included only when
    both concat_rule and concat_rule_level are set.
    """
    rules_by_level = collections.defaultdict(list)
    for rule in self.primitives:
      rules_by_level[0].append(rule)
    for level, rules in self.functions_by_level.items():
      for rule in rules:
        rules_by_level[level].append(rule)
    if self.concat_rule is not None and self.concat_rule_level is not None:
      rules_by_level[self.concat_rule_level].append(self.concat_rule)
    for level, rule in self.pass_through_rules.items():
      rules_by_level[level].append(rule)
    return immutabledict.immutabledict(rules_by_level)

  def _get_rule_strings_by_level(self):
    """Returns a mapping of precedence level to list of rule strings."""
    rule_strings_by_level = collections.defaultdict(list)
    for level, rules in self._get_rules_by_level().items():
      for rule in rules:
        rule_strings_by_level[level].append(rule.to_rule_string())
    return immutabledict.immutabledict(rule_strings_by_level)

  def get_rules_for_level(self, level):
    """Returns all rules defined at that level."""
    return self._get_rules_by_level().get(level, [])

  def get_args_for_level(self, level):
    """Returns all args of all rules defined at that level."""
    args = []
    for rule in self._get_rules_by_level().get(level, []):
      args.extend(rule.get_args())
    return args

  def get_rule_categories_for_level(self, level):
    """Returns the categories for which rules are defined at that level."""
    return frozenset(
        rule.category
        for rule in self.get_rules_for_level(level)
        if rule.category)

  def get_all_categories_for_level(
      self, level, options):
    """Returns all categories of that level that are used in any way.

    This covers categories output by rules at that level plus any arg category
    (from any rule) that options maps to that level.

    Args:
      level: The level for which to return the relevant categories.
      options: Options for determining the mapping between category and level.
    """
    categories = set(self.get_rule_categories_for_level(level))
    for rule in self.get_all_rules():
      for arg in rule.get_args():
        if options.level_by_category.get(arg.category) == level:
          categories.add(arg.category)
    return categories

  def _get_rules_by_category(self):
    """Returns a mapping of category to rules that output it."""
    rules_by_category = collections.defaultdict(list)
    for rule in self.get_all_rules():
      if rule.category:
        rules_by_category[rule.category].append(rule)
    return immutabledict.immutabledict(rules_by_category)

  def _get_rules_by_arg_category(self):
    """Returns a mapping of category to rules that consume it."""
    rules_by_arg_category = collections.defaultdict(list)
    for rule in self.get_all_rules():
      for category in rule.get_arg_categories():
        rules_by_arg_category[category].append(rule)
    return immutabledict.immutabledict(rules_by_arg_category)

  def get_start_symbol(self):
    """Returns the syntactic category output by the maximum precedence level.

    Corresponds to nltk.CFG.start().symbol().
    Must be unique. If there is no such unique category, then returns None.
    """
    max_level_categories = self.get_rule_categories_for_level(
        self.get_max_level())
    if len(max_level_categories) == 1:
      return next(iter(max_level_categories))
    else:
      return None

  def get_input_token_usage_counts(self):
    """Returns a map of input token to number of occurrences in the grammar.

    Input tokens are identified as rhs terms wrapped in single quotes.
    """
    token_counts = {}
    for rule in self.get_all_rules():
      for term in rule.get_rhs_terms():
        if term.startswith("'") and term.endswith("'"):
          token = term[1:-1]
          token_counts[token] = token_counts.get(token, 0) + 1
    return token_counts

  def get_output_token_usage_counts(self):
    """Returns a map of output token to number of occurrences in the grammar.

    Variables in the output sequences are not counted.
    """
    token_counts = {}
    for rule in self.get_all_rules():
      for term in rule.get_output_sequence():
        if not nltk_utils.is_output_token_a_variable(term):
          token_counts[term] = token_counts.get(term, 0) + 1
    return token_counts

  def validate(self, options):
    """Validates that GrammarSchema is fully-specified and valid.

    Args:
      options: Options with respect to which to validate. Used, for example,
        when validating whether the syntactic categories of each rule or arg
        match the allowable syntactic categories for the relevant precedence
        level.

    Raises:
      ValueError: If any issue was found.
    """
    if not self.primitives:
      raise ValueError('Lacking primitives')

    max_level = self.get_max_level()

    # Check that there are no levels (up to and including the max level) that
    # are lacking rules. An empty max level would otherwise only surface later
    # as a less specific missing-start-symbol error.
    rules_by_level = self._get_rules_by_level()
    for level in range(max_level + 1):
      if not rules_by_level.get(level, []):
        raise ValueError(f'No rules found at level: {level}')

    # Validate category levels
    for level, rules in rules_by_level.items():
      for rule in rules:
        rule.validate(options)
        rule_category_level = options.level_by_category.get(rule.category)
        if level != rule_category_level:
          raise ValueError(
              f'Category level mismatch: category {rule.category} should be in '
              f'level {rule_category_level} but found in level {level}')

    # Check that there are no unused categories
    rules_by_category = self._get_rules_by_category()
    rules_by_arg_category = self._get_rules_by_arg_category()
    rule_categories = set(rules_by_category)
    arg_categories = set(rules_by_arg_category)
    for category in rule_categories.difference(arg_categories):
      category_level = options.level_by_category.get(category)
      if category_level != max_level:
        raise ValueError(
            f'Category not consumed by any rule (level '
            f'{category_level} of {max_level}): {category}')
    for category in arg_categories.difference(rule_categories):
      raise ValueError(
          f'Category consumed but not produced by any rule: {category}')

    # Check that no categories are output by PassThroughRules alone.
    for rule in self.pass_through_rules.values():
      found_non_pass_through_rule = any(
          not isinstance(other_rule, PassThroughRule)
          for other_rule in rules_by_category[rule.category])
      if not found_non_pass_through_rule:
        raise ValueError(f'Category output by PassThroughRule must be output '
                         f'by at least one other rule to avoid ambiguity: '
                         f'{rule.category}')

    # Validate concat rule
    if self.concat_rule and not self.concat_rule_level:
      raise ValueError('Concat rule level missing but rule specified')
    if self.concat_rule_level and not self.concat_rule:
      raise ValueError('Concat rule level specified but rule missing')

    # Check that there are no duplicate function names
    level_by_function = {}
    for level, rules in self.functions_by_level.items():
      for rule in rules:
        function_string = rule.get_function_phrase_string()
        if function_string in level_by_function:
          raise ValueError(f'Duplicate function {function_string} in levels '
                           f'{level_by_function[function_string]} and {level}')
        level_by_function[function_string] = level

    # Validate start symbol
    if not self.get_start_symbol():
      raise ValueError(
          f'Start symbol not defined. Found '
          f'{len(self.get_rule_categories_for_level(max_level))} '
          f'categories at level {max_level}. Expected one.')

  def is_valid(self, options):
    """Returns whether the GrammarSchema is fully-specified and valid.

    Args:
      options: Options with respect to which to validate. Used, for example,
        when validating whether the syntactic categories of each rule or arg
        match the allowable syntactic categories for the relevant precedence
        level.
    """
    try:
      self.validate(options)
    except ValueError:
      return False
    return True

  def to_grammar_string(self):
    """Outputs a grammar in string format corresponding to this GrammarSchema.

    Assumes the GrammarSchema is valid (i.e. is_valid()==True). If called on an
    invalid schema, the result is not guaranteed to be parseable.

    Returns:
      A grammar in a string format that is parseable into nltk.FeatureGrammar.
    """
    lines = [f'% start {self.get_start_symbol()}']
    # Computed once rather than per level, since it renders every rule.
    rule_strings_by_level = self._get_rule_strings_by_level()
    # Print highest precedence levels first.
    for level in range(self.get_max_level(), -1, -1):
      lines.extend(rule_strings_by_level.get(level, []))
    return '\n'.join(lines)
