# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for generating random SCAN-like grammars.

As one of the key objectives of a conceptual learning benchmark is to evaluate
the ability of a system to learn from and about the rules of a task, it is
important to be able to vary these rules from example to example.
The GrammarGenerator introduced here is the library that enables this.
Specifically, it enables automatic generation of a variety of rule sets
(i.e., grammars) that serve as the what-if scenarios for the various reasoning
examples in the benchmark.

The space of grammars generated by this library is designed to encompass
grammars of varying size and complexity that are roughly similar in structure
to the SCAN ("Simplified version of the CommAI Navigation tasks") grammar
defined in [Lake and Baroni 2018].

As described in that paper, the original SCAN grammar describes a semantic
parsing task in which a natural language command like "jump twice and turn left"
is to be translated into a command sequence like "JUMP JUMP LTURN". The SCAN
grammar is introduced originally as a simple nonrecursive context-free (CFG)
"phrase-structure" grammar that describes a finite set of permissible input
sentences, together with an "interpretation function" that describes how input
sentences can be unambiguously mapped to corresponding output sequences.

The original SCAN phrase structure grammar contained rules like:
  C -> S and S      # Function rule (with prefix and postfix args)
  S -> V twice      # Function rule (with prefix arg only)
  V -> D            # Pass-through rule
  D -> turn left    # Primitive mapping (with multi-token phrase)
  U -> jump         # Primitive mapping (with single-token phrase)
etc.

The original SCAN interpretation function was described with rules like:
  [[jump]] = JUMP
  [[x1 and x2]] = [[x1]] [[x2]]
etc.

As illustrated in the files cscan/data/scan_*.fcfg, however, the same space of
input-output pairs can be equivalently represented using a single
FeatureGrammar, whose CFG core describes the space of permissible input
sentences, while a dedicated feature called "sem" (as in "semantics") is used to
express the input-to-output mappings.

GrammarGenerator and GrammarSchema deal with grammars that are "SCAN-like" in
the sense that they can be represented as FeatureGrammars of the form described
above, which like the SCAN grammar are unambiguous and nonrecursive, but which
can differ from SCAN in having arbitrary choices of input and output vocabulary
and arbitrary numbers of function rules, arguments, primitive mappings, and
function precedence levels.

Papers that inspired this implementation:
* [Lake and Baroni 2018] https://arxiv.org/pdf/1711.00350.pdf
  Original SCAN task (with original SCAN grammar).
* [Lake et al. 2019] https://arxiv.org/pdf/1901.04587.pdf
  MiniSCAN task.
* [Nye et al. 2020] https://arxiv.org/pdf/2003.05562.pdf
  Meta-learning solution to SCAN and MiniSCAN involving auto-generation of
  large numbers of SCAN-like grammars.

Concrete example of the type of grammar that this library might generate:
* scan_finite_nye_standardized.fcfg
"""

import copy
import dataclasses
from typing import Callable, Iterable, List, MutableSequence, Optional, Sequence

from absl import logging
import nltk
import numpy as np

from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils


def fixed_phrase_structure_template_from_grammar_schema(
    grammar,
    rule_filter = None
):
  """Returns a deep copy of a GrammarSchema with output sequences cleared.

  The provided grammar is never modified; all changes are made on a deep copy.

  NOTE(review): every rule yielded by get_all_rules() that passes the filter
  has its output_sequence attribute reset here, with no special-casing by rule
  type. The expectation that PassThroughRules are unaffected presumably relies
  on how PassThroughRule derives its output -- confirm against grammar_schema.

  Args:
    grammar: A GrammarSchema.
    rule_filter: Optional predicate taking a rule. When given, only the rules
      for which it returns True get an empty output sequence (and can thus be
      varied by a subsequent call to GrammarGenerator.generate_grammar); all
      other rules keep their output sequence. When None, every rule is cleared.

  Returns:
    A new GrammarSchema identical to the provided one except for the cleared
    output sequences. It is suitable as a template for
    GrammarGenerator.generate_grammar, to produce multiple grammars that share
    the same CFG rules but differ in their semantics features.
  """
  blanked = copy.deepcopy(grammar)
  for candidate in blanked.get_all_rules():
    # Skip rules that the caller asked to keep fixed.
    if rule_filter is not None and not rule_filter(candidate):
      continue
    candidate.output_sequence = []
  return blanked


@dataclasses.dataclass
class GrammarGenerator:
  """Generator of random SCAN-like FeatureGrammars.

  Specifically, generates grammars of the form describable by GrammarSchema, for
  use in evaluating the ability of a system to learn from and about the rules of
  a task.

  The difficulty of the learning task can be controlled by providing broader or
  narrower constraints around the types of grammars that can be generated.
  Specifically, this can be done in one of two ways, the first being via the
  GrammarOptions that are passed to GrammarGenerator in its constructor. These
  options allow the specification of overall bounds for the size and complexity
  of the grammars that are generated, such as the number of function precedence
  levels, the number of functions per level, the number of primitives, etc.

  The second way is via specification of an optional partially-populated
  GrammarSchema referred to as a "template". If a template GrammarSchema is
  provided, then GrammarGenerator will generate a new GrammarSchema with the
  same contents as the template to the extent that those contents were
  pre-populated, while generating random contents in accordance with the
  GrammarOptions for any values that were left unspecified either in the
  GrammarSchema itself or in any of the objects contained therein.

  As described in the documentation of GrammarSchema, an attribute is generally
  considered to be unspecified if its value is None or an empty List or an empty
  Dict. For details, see the documentation of GrammarSchema and its contained
  classes.

  Typical usage example 1 (GrammarOptions):
    generator = GrammarGenerator(options=inputs.GrammarOptions(...))
    grammar_schema = generator.generate_grammar_schema()

  Typical usage example 2 (GrammarOptions + template GrammarSchema):
    generator = GrammarGenerator(options=inputs.GrammarOptions(...))
    grammar_schema = generator.generate_grammar_schema(
        template=GrammarSchema(...)
    )

  Attributes:
    options: Options controlling the grammar generation process.
    rng: Random number generator.
  """
  options: inputs.GrammarOptions = dataclasses.field(
      default_factory=inputs.GrammarOptions)
  rng: np.random.RandomState = dataclasses.field(
      default_factory=np.random.RandomState)

  def _generate_primitives(self, schema,
                           possible_categories,
                           input_vocabulary,
                           output_vocabulary):
    """Fills schema.primitives with fully-populated PrimitiveMappings.

    Entries already present in schema.primitives are kept; only their unset
    fields (category, input_sequence, output_sequence) are filled in. Empty
    PrimitiveMapping stubs are appended first when the target number of
    primitives exceeds the number already present.

    Args:
      schema: Grammar schema to be populated in place. Existing contents will
        be honored.
      possible_categories: Available categories. Categories are popped from
        here as they are used, to avoid re-use later.
      input_vocabulary: Available input tokens. Tokens are popped from here as
        they are used, to avoid re-use later.
      output_vocabulary: Available output tokens. Tokens are popped from here
        as they are used, to avoid re-use later.

    Raises:
      ValueError: If the number of primitive categories drawn exceeds the
        number of categories left in possible_categories.
    """
    opts = self.options

    # Reuse the level-0 categories already known to the schema, if any;
    # otherwise draw a random number of fresh ones.
    primitive_categories = sorted(
        schema.get_all_categories_for_level(0, opts))
    upper_bound = min(opts.max_num_categories_per_level,
                      len(possible_categories))
    lower_bound = min(opts.min_num_categories_per_level, upper_bound)
    if not primitive_categories:
      num_to_draw = self.rng.randint(lower_bound, upper_bound + 1)
      if num_to_draw > len(possible_categories):
        raise ValueError(
            f'Number of primitive categories is bigger than number of possible '
            f'categories available: '
            f'{num_to_draw} > {len(possible_categories)}')
      primitive_categories.extend(
          possible_categories.pop() for _ in range(num_to_draw))

    # Target count: whatever the schema already holds, else the fixed option,
    # else a random draw from the configured range.
    target_count = len(schema.primitives) or opts.num_primitives
    if not target_count:
      target_count = self.rng.randint(opts.min_num_primitives,
                                      opts.max_num_primitives + 1)
    logging.debug('num_primitives = %d', target_count)

    # Append empty stubs until the target count is reached.
    for _ in range(target_count - len(schema.primitives)):
      schema.primitives.append(gs.PrimitiveMapping())

    # Fill in whatever is still missing on each primitive.
    for position, primitive in enumerate(schema.primitives):
      if primitive.category is None:
        # The first len(primitive_categories) positions map one-to-one onto the
        # categories so that each category is used at least once; later
        # positions pick a category at random.
        chosen = (
            position if position < len(primitive_categories) else
            self.rng.randint(len(primitive_categories)))
        primitive.category = primitive_categories[chosen]
      if not primitive.input_sequence:
        primitive.input_sequence = [input_vocabulary.pop()]
      if not primitive.output_sequence:
        primitive.output_sequence = [output_vocabulary.pop()]
    logging.debug('schema.primitives = %s', schema.primitives)

  def _generate_non_primitive_rule_stubs_for_level(self,
                                                   schema,
                                                   level):
    """Adds a random number of empty rule stubs to the given schema level.

    Appends empty FunctionRules to functions_by_level[level], and may add an
    empty PassThroughRule and/or ConcatRule for the level. Rules that already
    exist at the level are left as-is.

    Args:
      schema: Grammar schema to be populated in place. Existing contents will
        be honored.
      level: Precedence level for which to populate the rule stubs.
    """
    opts = self.options

    # FunctionRule stubs. There must be enough functions to consume every
    # category output by the previous level; a ConcatRule at this level counts
    # towards that, so one fewer function is needed in that case.
    logging.debug('Generating function stubs for level = %d', level)
    num_needed_for_coverage = len(
        schema.get_rule_categories_for_level(level - 1))
    if level == schema.concat_rule_level:
      num_needed_for_coverage -= 1
    lower_bound = max(opts.min_num_functions_per_level,
                      num_needed_for_coverage)
    upper_bound = max(lower_bound, opts.max_num_functions_per_level)
    num_functions = self.rng.randint(lower_bound, upper_bound + 1)
    logging.debug('num_functions = %d', num_functions)
    if num_functions > 0:
      # Only create the entry for this level when there is something to add.
      if level not in schema.functions_by_level:
        schema.functions_by_level[level] = []
      level_functions = schema.functions_by_level[level]
      for _ in range(num_functions):
        level_functions.append(gs.FunctionRule())

    # PassThroughRule stub. The random draw is made unconditionally so that the
    # random stream does not depend on whether a rule already exists.
    wants_pass_through = self.rng.rand() < opts.prob_pass_through_rule
    if wants_pass_through and level not in schema.pass_through_rules:
      logging.debug('Generating PassThroughRule stub for level = %d', level)
      schema.pass_through_rules[level] = gs.PassThroughRule()

    # ConcatRule stub, only at the designated level and only if none exists.
    if level == schema.concat_rule_level:
      if not schema.concat_rule:
        logging.debug('Generating ConcatRule stub for level = %d', level)
        schema.concat_rule = gs.ConcatRule()

  def _populate_rule_categories(self, schema, level,
                                categories):
    """Assigns an output category to each rule of the level that lacks one.

    Applies to the rules returned by schema.get_rules_for_level(level). Rules
    whose 'category' attribute is already set are skipped.

    Args:
      schema: Grammar schema to be populated in place. Rules for precedence
        levels lower than the given level should have already been generated.
      level: Precedence level for which to populate rule categories.
      categories: Categories to choose from. The first len(categories)
        non-pass-through rules processed each receive a distinct category, so
        that every category is output by at least one non-pass-through rule
        when there are enough such rules; all remaining rules receive a random
        category.
    """
    # Sort first (by string form) to obtain a deterministic starting order,
    # then shuffle so that the rule-to-category mapping is random.
    shuffled_rules = sorted(schema.get_rules_for_level(level), key=str)
    self.rng.shuffle(shuffled_rules)

    num_assigned_non_pass_through = 0
    for rule in shuffled_rules:
      if rule.category:
        continue
      logging.debug('rule (BEFORE) = %s', rule)
      is_pass_through = isinstance(rule, gs.PassThroughRule)
      if (not is_pass_through and
          num_assigned_non_pass_through < len(categories)):
        # Still covering the categories one by one.
        chosen = num_assigned_non_pass_through
      else:
        chosen = self.rng.randint(len(categories))
      rule.category = categories[chosen]
      if not is_pass_through:
        num_assigned_non_pass_through += 1
      logging.debug('rule (AFTER) = %s', rule)

  def _populate_arg_categories(self, schema,
                               level):
    """Assigns an input category to each arg of the level that lacks one.

    Applies to the args returned by schema.get_args_for_level(level). Args
    whose 'category' attribute is already set are skipped.

    Args:
      schema: Grammar schema to be populated in place. Rules for precedence
        levels lower than the given level should have already been generated.
      level: Precedence level for which to populate arg categories.
    """
    # Args of this level consume exactly the categories that the rules of the
    # previous level output.
    prev_level_categories = sorted(
        schema.get_rule_categories_for_level(level - 1))
    logging.debug('categories for prev level: %s', prev_level_categories)

    # Sort for a deterministic starting order, then shuffle so that the
    # arg-to-category mapping is random.
    shuffled_args = sorted(schema.get_args_for_level(level))
    self.rng.shuffle(shuffled_args)

    num_categories = len(prev_level_categories)
    for position, arg in enumerate(shuffled_args):
      if arg.category:
        continue
      logging.debug('arg (BEFORE) = %s', arg)
      # Positions below num_categories map one-to-one onto the categories so
      # that each category gets consumed; beyond that the choice is random.
      chosen = (
          position if position < num_categories else
          self.rng.randint(num_categories))
      arg.category = prev_level_categories[chosen]
      logging.debug('arg (AFTER) = %s', arg)

  def _populate_arg_variables(self, args):
    """Gives each argument that has no variable a positional variable name.

    The variable for the argument at 1-based position i is built from 'x{i}'
    via nltk_utils.add_variable_prefix. Arguments that already have a variable
    keep it, but still occupy their position in the numbering.

    Args:
      args: Arguments to be populated. These are expected to represent the full
        list of arguments of a single RuleSchema, in their original order.
    """
    position = 0
    for arg in args:
      position += 1
      if arg.variable is not None:
        continue
      arg.variable = nltk_utils.add_variable_prefix(f'x{position}')

  def _generate_output_sequence(
      self, args,
      output_vocabulary,
      remaining_size):
    """Generates a random output sequence using the given arg variables.

    The result contains every arg variable at least once, optionally repeated,
    interleaved in random order with a random selection of raw output tokens.

    Args:
      args: Arguments whose variables should be used in the output.
      output_vocabulary: Available output tokens. Unless
        options.reuse_raw_tokens_in_output_sequence is set, tokens are popped
        from here as they are used, to avoid re-use later.
      remaining_size: A list containing a single integer that is the remaining
        cumulative output size; a negative value means no cumulative limit.
        The value is reduced in place by the length of the generated sequence
        (never going below zero).

    Returns:
      Sequence of output tokens.

    Raises:
      ValueError: If options.max_output_sequence_size is non-negative but too
        small to hold one occurrence of each variable.
    """
    options = self.options
    max_repetitions = options.max_repetitions_per_token_in_output_sequence

    # var_sequence receives exactly one occurrence of each variable, which is
    # always kept; any additional repetitions go to output_sequence, which may
    # be truncated later.
    output_sequence = []
    var_sequence = []
    for arg in args:
      var_sequence.append(arg.variable)
      num_extra_repetitions = self.rng.randint(0, max_repetitions)
      for _ in range(num_extra_repetitions):
        output_sequence.append(arg.variable)

    # Add raw tokens to raw_sequence.
    raw_sequence = []
    num_raw_tokens = max(
        0,
        self.rng.randint(
            options.min_unique_raw_tokens_in_output_sequence,
            options.max_unique_raw_tokens_in_output_sequence + 1))
    for _ in range(num_raw_tokens):
      num_repetitions = self.rng.randint(1, max_repetitions + 1)
      token = (
          output_vocabulary[self.rng.randint(0, len(output_vocabulary))]
          if options.reuse_raw_tokens_in_output_sequence else
          output_vocabulary.pop())
      for _ in range(num_repetitions):
        raw_sequence.append(token)
    if options.max_raw_tokens_in_output_sequence > 0:
      # Shuffle before truncating so that the surviving tokens are random.
      self.rng.shuffle(raw_sequence)
      raw_sequence = raw_sequence[:options.max_raw_tokens_in_output_sequence]

    # Determine the size limit for this sequence; -1 means unlimited.
    max_size = -1
    if options.max_output_sequence_size >= 0:
      if options.max_output_sequence_size < len(var_sequence):
        raise ValueError(
            f'Max_output_sequence_size {options.max_output_sequence_size} '
            f'must be large enough to contain all variables {len(var_sequence)}.'
        )
      max_size = options.max_output_sequence_size
    # The remaining cumulative budget applies whenever it is set and is either
    # the only limit or the tighter one. Without the `max_size < 0` case the
    # budget would be ignored whenever the per-sequence size is unlimited.
    remaining = remaining_size[0]
    if remaining >= 0 and (max_size < 0 or remaining < max_size):
      max_size = remaining

    # Combine variables and raw tokens. The optional part is truncated to fit
    # the limit, but one occurrence of each variable is always kept, so the
    # cumulative budget is honored on a best-effort basis only.
    output_sequence.extend(raw_sequence)
    if max_size >= 0:
      self.rng.shuffle(output_sequence)
      output_sequence = output_sequence[:max(0, max_size - len(var_sequence))]
    output_sequence.extend(var_sequence)
    self.rng.shuffle(output_sequence)
    if remaining_size[0] >= 0:
      remaining_size[0] = max(0, remaining_size[0] - len(output_sequence))
    return output_sequence

  def _populate_function_phrase_and_arg_stubs(
      self, rule,
      input_vocabulary):
    """Fills in a FunctionRule's phrase, arg counts and empty arg stubs.

    Each of function_phrase, num_args, num_postfix_args and args is only
    generated when it is not yet set on the rule.

    Args:
      rule: Rule to be populated in place. Existing contents will be honored.
      input_vocabulary: Available input tokens. A token is popped from here
        when a function phrase has to be generated, to avoid re-use later.
    """
    opts = self.options

    if not rule.function_phrase:
      rule.function_phrase = [input_vocabulary.pop()]

    if rule.num_args is None:
      rule.num_args = self.rng.randint(opts.min_num_args,
                                       opts.max_num_args + 1)

    if rule.num_postfix_args is None:
      # A rule cannot have more postfix args than it has args in total.
      fewest_postfix = min(rule.num_args, opts.min_num_postfix_args)
      most_postfix = min(rule.num_args, opts.max_num_postfix_args)
      rule.num_postfix_args = self.rng.randint(fewest_postfix,
                                               most_postfix + 1)

    # Generate arg stubs if none specified yet.
    if not rule.args:
      rule.args.extend(gs.FunctionArg() for _ in range(rule.num_args))

  def _determine_categories_for_level(
      self, schema, level, max_level,
      possible_categories):
    """Returns categories to be output by the given level.

    If schema.get_all_categories_for_level already reports categories for the
    level (e.g. from rules at this level, or from arg categories of rules at
    the following level), those are used as-is. Otherwise, categories are
    drawn from possible_categories.

    Args:
      schema: Grammar schema to be populated. Rules for precedence levels lower
        than the given level should have already been generated.
      level: Precedence level for which to determine the categories.
      max_level: Largest precedence level that will eventually be added to the
        grammar schema.
      possible_categories: Available categories. Categories are popped from
        here as they are used, to avoid re-use later.

    Returns:
      List of categories for the level: the sorted pre-existing categories if
      there are any, otherwise the newly drawn ones.

    Raises:
      ValueError: If more categories are needed than possible_categories can
        supply.
    """
    opts = self.options

    # Only select new categories if no existing categories are defined.
    existing = sorted(schema.get_all_categories_for_level(level, opts))
    if existing:
      return existing

    if level == max_level:
      # The final level must output a single category in order for the grammar
      # to have a well-defined start symbol.
      num_to_draw = 1
    else:
      # Keep the number of categories small enough that each one can be output
      # by at least one non-pass-through rule (functions plus the ConcatRule,
      # if it lives at this level).
      num_non_pass_through_rules = len(
          schema.functions_by_level.get(level, []))
      if level == schema.concat_rule_level:
        num_non_pass_through_rules += 1
      upper_bound = min(opts.max_num_categories_per_level,
                        num_non_pass_through_rules,
                        len(possible_categories))
      lower_bound = min(opts.min_num_categories_per_level, upper_bound)
      num_to_draw = self.rng.randint(lower_bound, upper_bound + 1)
    logging.debug('num_categories_for_level = %d', num_to_draw)

    if num_to_draw > len(possible_categories):
      raise ValueError(
          f'Num categories for level {level} bigger than number of '
          f'possible categories available: '
          f'{num_to_draw} > {len(possible_categories)}')

    # Draw the actual categories.
    drawn = [possible_categories.pop() for _ in range(num_to_draw)]
    logging.debug('categories for level %d = %s', level, drawn)
    return drawn

  def _generate_basic_rules_for_level(
      self, schema, level, max_level,
      should_create_new_non_primitive_rules,
      possible_categories,
      input_vocabulary,
      output_vocabulary):
    """Populates one precedence level of the schema with random rules.

    Level 0 is handled entirely by _generate_primitives. For any other level,
    the steps are: optionally create rule stubs, assign rule categories,
    populate function phrases and arg stubs, assign arg categories, and assign
    arg variables. Existing fully-specified content is left as-is. Output
    sequences of non-primitive rules are not populated here.

    Args:
      schema: Grammar schema to be populated in place. Rules for precedence
        levels lower than the given level should have already been generated.
      level: Precedence level for which to generate rules.
      max_level: Largest precedence level that will eventually be added to the
        grammar schema.
      should_create_new_non_primitive_rules: If true, then will create new
        non-primitive rules from scratch. If false, then will only fill in the
        details (category, args, etc.) of rules that are already present.
      possible_categories: Available categories. Categories are popped from
        here as they are used, to avoid re-use later.
      input_vocabulary: Available input tokens. Tokens are popped from here as
        they are used, to avoid re-use later.
      output_vocabulary: Available output tokens. Tokens are popped from here
        as they are used, to avoid re-use later.
    """
    logging.debug('Generating rules for level = %d', level)

    # Primitives live at level 0 and have their own generation routine.
    if level == 0:
      self._generate_primitives(
          schema, possible_categories, input_vocabulary, output_vocabulary)
      return

    # Step 1: rule stubs (skipped when the template already provides rules).
    if should_create_new_non_primitive_rules:
      self._generate_non_primitive_rule_stubs_for_level(schema, level)

    # Step 2: categories output by the rules of this level.
    logging.debug('Populating rule categories for level: %d', level)
    level_categories = self._determine_categories_for_level(
        schema, level, max_level, possible_categories)
    self._populate_rule_categories(schema, level, level_categories)

    # Step 3: function phrases and arg stubs.
    logging.debug('Populating function phrases and arg stubs for level: %d',
                  level)
    for function_rule in schema.functions_by_level.get(level, []):
      logging.debug('rule (BEFORE) = %s', function_rule)
      self._populate_function_phrase_and_arg_stubs(function_rule,
                                                   input_vocabulary)
      logging.debug('rule (AFTER) = %s', function_rule)

    # Step 4: categories consumed by the args of this level.
    logging.debug('Populating arg categories for level: %d', level)
    self._populate_arg_categories(schema, level)

    # Step 5: variables of the args of every rule at this level.
    logging.debug('Populating arg variables for level: %d', level)
    for level_rule in schema.get_rules_for_level(level):
      logging.debug('rule (BEFORE) = %s', level_rule)
      self._populate_arg_variables(level_rule.get_args())
      logging.debug('rule (AFTER) = %s', level_rule)

  def _populate_output_sequences(
      self, schema,
      output_vocabulary):
    """Generates the missing output sequences of all rules in the schema.

    Only rules that are instances of gs.RuleWithExplicitOutputSequence and
    whose output_sequence is still empty are populated. Rules are visited in
    random order, sharing a single cumulative output size budget taken from
    options.max_cumulative_output_sequence_size.

    Args:
      schema: Grammar schema containing the rules to be populated in place.
      output_vocabulary: Available output tokens. Tokens may be popped from
        here as they are used, to avoid re-use later.
    """
    logging.debug('Populating output sequences')
    shuffled_rules = list(schema.get_all_rules())
    self.rng.shuffle(shuffled_rules)

    # Single-element list so that the generator can decrement it in place.
    size_budget = [self.options.max_cumulative_output_sequence_size]
    for rule in shuffled_rules:
      if not isinstance(rule, gs.RuleWithExplicitOutputSequence):
        continue
      if rule.output_sequence:
        continue
      logging.debug('rule (BEFORE) = %s', rule)
      rule.output_sequence = self._generate_output_sequence(
          rule.get_args(), output_vocabulary, size_budget)
      logging.debug('rule (AFTER) = %s', rule)

  def generate_grammar_schema(self,
                              template = None
                             ):
    """Randomly generates a SCAN-like grammar in GrammarSchema format.

    Args:
      template: Optional partially-specified schema to constrain the types of
        GrammarSchemas that can be generated. It is deep-copied and never
        modified.

    Returns:
      A newly-generated GrammarSchema. If a template was specified, then this
      will be a clone of the template, with any unspecified attributes (either
      in the GrammarSchema itself or in any of the objects contained therein)
      replaced with values generated randomly in accordance with the options in
      GrammarGenerator. If no template was specified, then the whole
      GrammarSchema will be randomly generated from scratch.

    Raises:
      ValueError: If validation of the generated schema raises a ValueError;
        the original error is chained as the cause.
    """
    opts = self.options

    if template is None:
      schema = gs.GrammarSchema()
    else:
      schema = copy.deepcopy(template)
    logging.debug('Generating grammar schema for template: %s', template)

    # For each level, build a shuffled list of possible categories, so that
    # popping from the list yields a random selection. Sorting first makes the
    # result depend only on the random state, not on set iteration order.
    possible_categories_by_level = {}
    for level, level_categories in opts.possible_categories_by_level.items():
      shuffled_categories = sorted(level_categories)
      self.rng.shuffle(shuffled_categories)
      possible_categories_by_level[level] = shuffled_categories
    logging.debug('possible_categories_by_level = %s',
                  possible_categories_by_level)

    # Likewise build shuffled lists of the input and output tokens that the
    # schema does not use yet.
    used_input_tokens = set(schema.get_input_token_usage_counts())
    input_vocabulary = sorted(opts.input_vocabulary - used_input_tokens)
    self.rng.shuffle(input_vocabulary)
    used_output_tokens = set(schema.get_output_token_usage_counts())
    output_vocabulary = sorted(opts.output_vocabulary - used_output_tokens)
    self.rng.shuffle(output_vocabulary)

    # Number of precedence levels: taken from the template's functions if it
    # has any, else from the fixed option, else drawn at random.
    num_levels = (
        schema.get_max_level() if schema.functions_by_level else None)
    if num_levels is None:
      num_levels = opts.num_precedence_levels
    if num_levels is None:
      num_levels = self.rng.randint(opts.min_num_precedence_levels,
                                    opts.max_num_precedence_levels + 1)
    logging.debug('num_levels = %d', num_levels)

    # New non-primitive rules are only created when the template provides no
    # functions of its own.
    should_create_new_non_primitive_rules = not schema.functions_by_level

    # Decide the level of the ConcatRule (if any). This must happen early, as
    # it can affect the minimum number of FunctionRules required at any given
    # level. The random draws are only made when both preconditions hold.
    if should_create_new_non_primitive_rules:
      if schema.concat_rule_level is None:
        if self.rng.rand() < opts.prob_concat_rule:
          schema.concat_rule_level = self.rng.randint(1, num_levels + 1)
    logging.debug('concat_rule_level = %s', schema.concat_rule_level)

    # Generate rules level by level, starting at the lowest (primitives). The
    # order matters: the args of a level are determined in part by the
    # categories output by the previous level.
    for level in range(num_levels + 1):
      self._generate_basic_rules_for_level(
          schema,
          level,
          num_levels,
          should_create_new_non_primitive_rules,
          possible_categories_by_level[level],
          input_vocabulary,
          output_vocabulary)

    self._populate_output_sequences(schema, output_vocabulary)
    logging.debug('schema.functions_by_level = %s', schema.functions_by_level)

    # Ensure that the generated GrammarSchema is valid.
    try:
      schema.validate(opts)
    except ValueError as validation_error:
      message = f'Generated invalid GrammarSchema: {schema}'
      logging.exception(message)
      raise ValueError(message) from validation_error
    return schema

  def generate_grammar(
      self,
      template = None
  ):
    """Randomly generates a SCAN-like grammar as an nltk FeatureGrammar.

    Generates a GrammarSchema via generate_grammar_schema, renders it with
    to_grammar_string, and parses that string into a FeatureGrammar.

    Args:
      template: Optional partially-specified schema to constrain the types of
        GrammarSchemas that can be generated.

    Returns:
      A randomly-generated nltk.FeatureGrammar with SCAN-like rules.
    """
    schema = self.generate_grammar_schema(template)
    grammar_string = schema.to_grammar_string()
    return nltk.grammar.FeatureGrammar.fromstring(grammar_string)
