# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for representing a conceptual learning benchmark.

A conceptual learning benchmark, at a high level, measures the ability of a
system to explicitly learn from and about the rules of a task. In order to test
this ability, a conceptual learning dataset and the examples therein are
expected to satisfy a number of key properties, which together guide a large
part of the design of the classes in this library.

These key properties include the following:
- Each Example in the dataset consists of a "request" that needs to be
  translated into a "reply" using some "context".
- The context consists of a set of examples that represent the knowledge based
  on which the translation is to be performed.
- The translation process is describable either fully or in part by a set of
  "rules". These rules could theoretically be of arbitrary form (e.g., grammar
  rules, first order logic rules, knowledge triples, etc.). In this library,
  for maximum flexibility, rules are represented as arbitrary strings.
- Rules can be explicitly referred to in the task by passing the rule itself
  as a request. In that case, the reply is expected to be either TRUE ('1'),
  FALSE ('0'), or UNKNOWN ('?'), indicating whether the given rule holds in the
  given task in the presence of the given context.
- The rules of the task can vary from example to example, and thus must be
  conveyed either explicitly or implicitly via the example's context. Depending
  on how the rule is represented in the context, we can classify each rule as
  either "explicit" or "hidden". An explicit rule would be represented by a
  corresponding example in the context that explicitly asserts that rule to be
  TRUE ('1'). A hidden rule would not be explicitly asserted to be true, but
  would rather be illustrated indirectly via some number of examples whose
  behavior it affects.
- Both the top-level dataset and the nested contexts will thus consist of a
  mixture of examples that explicitly map potential rules to their truth values
  and examples that illustrate the end-to-end task that follows these rules.
  At the level of the top-level dataset, we refer to these as "rule" and
  "non-rule" examples, respectively.
- Each Example also has a "qualifier", which together with the "reply" makes up
  the output of the task. The qualifier indicates whether the reply holds
  monotonically with respect to the context, or whether it follows only
  defeasibly (e.g., in the case where a rule is induced from examples).

The above properties lead in turn to the following design considerations:
- Since the context of each example is itself a set of examples, Example and
  ExampleSet need to be co-dependent -- each containing the other, up to some
  arbitrary level of recursion.
- In order to support this use case while allowing examples to be organized into
  hash-based data structures such as Sets and Dicts, we introduce two different
  representations of an example set: ExampleSet (which is mutable) and
  FrozenExampleSet (which is immutable). We make Example immutable in order to
  enable a stable hash and accordingly use the immutable FrozenExampleSet for
  representing an example's context. We use the mutable ExampleSet in cases
  where an ExampleSet needs to be built up programmatically, e.g., when
  constructing the benchmark dataset as a whole.
- For bookkeeping purposes, we find it convenient to track some additional
  information about each Example and ExampleSet beyond the basic fields that
  are to be visible to the learner. For example, for each Example we may want to
  track the rules that it depends on, and for each ExampleSet we may want to
  track its full set of rules (including the hidden ones) and the correspondence
  between each rule and the examples that were provided to illustrate it. We
  bundle all such extra information into a "metadata" structure, which is to be
  hidden from the learner.
"""

import abc
import collections.abc
import copy
import enum
import hashlib
import itertools
import json
from typing import (AbstractSet, Any, Callable, Dict, Iterable, Iterator, List,
                    Mapping, MutableMapping, Optional, Sequence, Set, Tuple,
                    TypeVar, Union)

import attr
import nltk
import numpy as np

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import production_trees

# Generic type variables for mapping keys (_K) and values (_V).
_K = TypeVar('_K')
_V = TypeVar('_V')


def extend_mapping_of_lists_unique(
    target_mapping,
    source_mapping):
  """Merges the contents of source_mapping into target_mapping w/o duplicates.

  The merge happens in place; nothing is returned. Existing values keep their
  order and new values are appended in the order they appear in
  source_mapping. Duplicates within a single source list are also collapsed.

  Args:
    target_mapping: Mapping to be extended.
    source_mapping: Mapping whose values are to be merged in. Keys of
      source_mapping that are not already in target_mapping will be added.
      Values from source_mapping that are not already in the corresponding value
      list in target_mapping will be appended.
  """
  for key, values in source_mapping.items():
    # Fetch (or create) the destination list once per key instead of
    # re-indexing target_mapping for every value.
    target_values = target_mapping.setdefault(key, [])
    for value in values:
      # List membership (rather than a set) keeps support for unhashable values.
      if value not in target_values:
        target_values.append(value)


# The str mixin lets json.dumps encode members directly as their values. See:
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
@enum.unique
class Qualifier(str, enum.Enum):
  """Kind of reasoning behind a conceptual learning example's reply.

  A reply either holds monotonically with respect to the context (it can be
  deduced) or only defeasibly (it had to be induced).
  """
  M = 'M'  # Monotonic (deductive)
  D = 'D'  # Defeasible (inductive)

  def __str__(self):
    """Returns the bare member value, e.g. 'M' rather than 'Qualifier.M'."""
    return f'{self.value}'


@enum.unique
class RequestType(str, enum.Enum):
  """Classification of what the conceptual learning request tests."""

  # Example illustrating the behavior of the end-to-end task. (I.e., an example
  # illustrating how the rules of the task are applied to a given input.)
  NON_RULE = 'non_rule_request'

  # Example whose request is a rule, either positive or negative. (Reply should
  # be '0' or '1', or '?' when the rule's truth value is unknown.)
  RULE = 'rule_request'


class RuleReply:
  """Dedicated replies to choose from when the request is a rule.

  These are plain string constants (not enum members), so they compare
  directly against an example's reply string.
  """
  FALSE: str = '0'
  TRUE: str = '1'
  # '?' is also used as the reply of non-rule examples whose reply is unknown.
  UNKNOWN: str = '?'


@enum.unique
class ExampleType(str, enum.Enum):
  """Classification of the example type.

  This allows us to keep track of the number of generated top-level examples
  in each fine-grained bucket.  We expect that for analysis the counts will be
  aggregated across different dimensions to answer questions that are not
  necessarily at the most fine-grained level, such as "how many rule examples
  with true reply did the learner answer correctly", so the enum values are
  designed to be uniformly structured and easy to parse.  Specifically:

  The value of each enum consists of four slots separated by an underscore '_':
  - 'rule' or 'nonrule': for rule or non-rule request type.
  - '' or '?': for known or unknown reply.
  - '' or '0' or '1': '' is a placeholder used for non-rule examples and for
    rule examples with an unknown reply; '0' or '1' is the FALSE or TRUE rule
    reply.
  - 'D' or 'M': for defeasible (inductive) or monotonic (deductive) qualifier.

  Only the combinations listed below are defined; e.g., there is no member for
  an unknown reply with the 'M' qualifier.
  """
  NONRULE_KNOWN_D: str = 'nonrule___D'
  NONRULE_KNOWN_M: str = 'nonrule___M'
  NONRULE_UNKNOWN_D: str = 'nonrule_?__D'

  RULE_KNOWN_FALSE_D: str = 'rule__0_D'
  RULE_KNOWN_FALSE_M: str = 'rule__0_M'
  RULE_KNOWN_TRUE_D: str = 'rule__1_D'
  RULE_KNOWN_TRUE_M: str = 'rule__1_M'
  RULE_UNKNOWN_D: str = 'rule_?__D'


# Every ExampleType grouped under the RequestType it belongs to ('nonrule_*'
# values under NON_RULE, 'rule_*' values under RULE).
EXAMPLE_TYPES_BY_REQUEST_TYPE = {
    RequestType.NON_RULE: [
        ExampleType.NONRULE_KNOWN_D, ExampleType.NONRULE_KNOWN_M,
        ExampleType.NONRULE_UNKNOWN_D
    ],
    RequestType.RULE: [
        ExampleType.RULE_KNOWN_TRUE_D, ExampleType.RULE_KNOWN_TRUE_M,
        ExampleType.RULE_KNOWN_FALSE_D, ExampleType.RULE_KNOWN_FALSE_M,
        ExampleType.RULE_UNKNOWN_D
    ]
}


@enum.unique
class Knownness(str, enum.Enum):
  """Classification of the example by knownness.

  UNKNOWN: The reply is '?'.
  KNOWN_DEFEASIBLE: The reply is known and the qualifier is 'D'.
  KNOWN_MONOTONIC: The reply is known and the qualifier is 'M'.
  """
  UNKNOWN: str = 'UNKNOWN'
  KNOWN_DEFEASIBLE: str = 'KNOWN_DEFEASIBLE'
  KNOWN_MONOTONIC: str = 'KNOWN_MONOTONIC'


@enum.unique
class Triviality(str, enum.Enum):
  """How trivially a top-level example can be answered from its context.

  IDENTICAL: Some context example has exactly the same request and reply as
    the top-level example (its context, qualifier and metadata may differ).
  REPHRASE_CONTEXT_RULE_AS_NONRULE: The top-level example is a non-rule example
    with request q and reply r, and some context example is a rule asserting
    [q] = [r].
  REPHRASE_CONTEXT_NONRULE_AS_RULE: The top-level example is a rule asserting
    [q] = [r], and some context example is a non-rule example with request q
    and reply r.
  NEGATION: The top-level example is a negative rule example that directly
    contradicts a rule or non-rule example of the context.
  NON_TRIVIAL: None of the above applies.
  """
  IDENTICAL: str = 'IDENTICAL'
  REPHRASE_CONTEXT_RULE_AS_NONRULE: str = 'REPHRASE_CONTEXT_RULE_AS_NONRULE'
  REPHRASE_CONTEXT_NONRULE_AS_RULE: str = 'REPHRASE_CONTEXT_NONRULE_AS_RULE'
  NEGATION: str = 'NEGATION'
  NON_TRIVIAL: str = 'NON_TRIVIAL'


@enum.unique
class IllustrationQuality(str, enum.Enum):
  """Classification of how well a rule is illustrated in a context.

  We also use the same enums to indicate whether a top-level example has been
  illustrated "fairly".
  """
  # NOTE(review): the criteria separating GOOD from POOR are defined by the code
  # that assigns these values, not here.
  GOOD: str = 'GOOD'
  POOR: str = 'POOR'


@attr.s(auto_attribs=True)
class ExampleTrainSimilarityMetadata:
  """How closely a validation/test example resembles the train set.

  Meaningful only for examples from the validation or test split. Every field
  defaults to zero / False. Throughout, "same-request train examples" are the
  train-set examples whose request equals this example's request, and
  similarity between two examples is measured by the overlap of their context
  examples.

  Attributes:
    num_train_examples_with_same_request: Count of same-request train examples.
    num_train_examples_with_same_request_and_reply: Count of same-request train
      examples whose reply also matches.
    num_train_examples_with_same_request_and_output: Count of same-request train
      examples whose reply and qualifier both match.
    num_unique_train_replies: Number of distinct replies among the same-request
      train examples.
    nearest_reply_matches: Whether the most similar same-request train example
      has the same reply as this one.
    nearest_qualifier_matches: Whether the most similar same-request train
      example has the same qualifier as this one.
    nearest_similarity: Similarity of the most similar same-request train
      example to this one.
    consensus_reply_matches: Whether the reply chosen by a similarity-weighted
      vote over all same-request train examples equals this one's reply.
    consensus_qualifier_matches: Whether the qualifier chosen by a
      similarity-weighted vote over all same-request train examples equals this
      one's qualifier.
    fraction_train_examples_with_same_request_and_reply: (Read-only property.)
      num_train_examples_with_same_request_and_reply divided by
      num_train_examples_with_same_request, or 0.0 when the latter is zero.
    fraction_train_examples_with_same_request_and_output: (Read-only property.)
      num_train_examples_with_same_request_and_output divided by
      num_train_examples_with_same_request, or 0.0 when the latter is zero.
  """
  num_train_examples_with_same_request: int = 0
  num_train_examples_with_same_request_and_reply: int = 0
  num_train_examples_with_same_request_and_output: int = 0
  num_unique_train_replies: int = 0
  nearest_reply_matches: bool = False
  nearest_qualifier_matches: bool = False
  nearest_similarity: float = 0
  consensus_reply_matches: bool = False
  consensus_qualifier_matches: bool = False

  @property
  def fraction_train_examples_with_same_request_and_reply(self):
    """Share of same-request train examples whose reply also matches."""
    denominator = self.num_train_examples_with_same_request
    return (self.num_train_examples_with_same_request_and_reply / denominator
            if denominator else 0.0)

  @property
  def fraction_train_examples_with_same_request_and_output(self):
    """Share of same-request train examples whose reply+qualifier match."""
    denominator = self.num_train_examples_with_same_request
    return (self.num_train_examples_with_same_request_and_output / denominator
            if denominator else 0.0)


def _example_set_value_serializer(instance, field, value):
  """Value serializer for serialization of Examples and ExampleSets.

  Meant to be passed as `value_serializer` to `attr.asdict`, which calls it
  with the instance being serialized, the attrs field, and the field's value.
  `field` may be None (e.g. for elements of collections), which is why it is
  checked before use.

  Args:
    instance: Unused.
    field: The attrs attribute being serialized, or None.
    value: The value to convert.

  Returns:
    A JSON-friendly representation of `value`, or `value` unchanged when no
    special handling applies.
  """
  del instance  # Unused.
  if isinstance(
      value,
      (AbstractExampleContainer, Example, ExampleMetadata, ExampleSetMetadata)):
    # These classes have their own serialization functions.
    return value.serialize()
  if isinstance(value, nltk.grammar.Production):
    # nltk.Production is not directly JSON serializable, so we store its string
    # form, itself JSON-encoded (ExampleMetadata.deserialize json.loads it).
    return json.dumps(str(value))
  if isinstance(value, production_composition.ProductionProvenance):
    # ProductionProvenance provides its own JSON conversion.
    return value.to_json()
  if (field is not None and field.name == '_examples' and
      isinstance(value, collections.abc.Sequence)):
    # In ExampleSet, this is a list; in FrozenExampleSet it is a tuple.
    # In unstructured form, this always becomes a list. We need to explicitly
    # call Example.serialize here, as it would not get called by the default
    # serializer for list and tuple types.
    return [v.serialize() for v in value]
  return value


@attr.s(auto_attribs=True)
class ExampleMetadata:
  """Metadata about a conceptual learning example (hidden from learner).

  We store here any information that we may want to track about an Example, but
  which the learner is not intended to have access to. This information may be
  used, for example, for constructing principled train-test splits, or in
  generation of diagnostic metrics, etc.

  Attributes:
    rules: Rules that were used in generating the given example (and which can
      thus also be considered to be indirectly illustrated in the example). This
      includes the distractor rules.  In case of a negative rule example, this
      field records the rules used to generate the original request.
    target_rule: The rule which the example is intended to illustrate when it is
      generated.  Populated only for examples in a context.
    derivation_level: The number of times source rule productions are composed
      to generate the example.
    original_reply: Populated only for non-rule examples if the example's reply
      is "?" (unknown), with the generated original reply.
    num_variables: The number of unbound variables of a rule example.
    applied_edits: The list of edits applied to the positive rule to create the
      distractor rule.  Populated only for negative rule examples.
    new_source_production_by_source_production: Mapping from the source
      production string to the new source production string used when generating
      the negative rule example.  Populated only for negative rule examples
      created with the alternative grammar strategy.
    original_request: Populated only if the example is a negative rule example.
    as_rule: The example's request and reply expressed as a rule.  Populated
      only for non-rule examples.
    distractor_rules_by_unreliable_rule: The mapping from the original
      unreliable rules used as dependency when creating the example to the list
      of distractor rules that were actually used.  These distractor rules also
      appear in the rules attribute.
    production: The example expressed as a production.  Used only internally
      during context generation to guarantee consistency.
    production_provenance: The sequence of production compositions via which the
      example was built.
    train_similarity: Metadata about an example's similarity to examples from
      the train set. Only populated in the validation and test splits.
    input_length_standard: Length of the input (context + request) in tokens,
      based on the STANDARD string representation and whitespace tokenizer. Only
      populated for top-level examples.
    output_length_standard: Length of the output (reply + qualifier) in tokens,
      based on the STANDARD string representation and whitespace tokenizer. Only
      populated for top-level examples.
    input_length_compact: Length of the input (context + request) in tokens,
      based on the COMPACT string representation and T5X's default
      SentencePieceVocabulary tokenizer. Only populated for top-level examples.
    output_length_compact: Length of the output (reply + qualifier) in tokens,
      based on the COMPACT string representation and T5X's default
      SentencePieceVocabulary tokenizer. Only populated for top-level examples.
  """

  rules: Set[str] = attr.Factory(set)
  target_rule: str = ''
  derivation_level: int = 0
  original_reply: str = ''
  num_variables: int = 0
  applied_edits: List[str] = attr.Factory(list)
  new_source_production_by_source_production: Dict[str, str] = (
      attr.Factory(dict))
  original_request: str = ''
  as_rule: str = ''
  distractor_rules_by_unreliable_rule: Dict[str, List[str]] = attr.Factory(dict)
  production: Optional[nltk.grammar.Production] = attr.ib(default=None)
  production_provenance: Optional[
      production_composition.ProductionProvenance] = attr.ib(default=None)
  train_similarity: Optional[ExampleTrainSimilarityMetadata] = None

  input_length_standard: int = 0
  output_length_standard: int = 0
  input_length_compact: int = 0
  output_length_compact: int = 0

  def __bool__(self) -> bool:
    """Returns True if any rules are recorded in this metadata.

    Only `rules` is consulted; all other fields (e.g. target_rule, production,
    the length counters) have no effect on truthiness.
    """
    return bool(self.rules)

  def iter_distractor_rules(self) -> Iterator[str]:
    """Yields every distractor rule, grouped by the unreliable rule it replaced.

    Order follows the insertion order of distractor_rules_by_unreliable_rule.
    """
    for distractor_rules in self.distractor_rules_by_unreliable_rule.values():
      yield from distractor_rules

  def serialize(self) -> Dict[str, Any]:
    """Returns a representation of this object using JSON-serializable types."""
    return attr.asdict(self, value_serializer=_example_set_value_serializer)

  @classmethod
  def deserialize(
      cls, unstructured_metadata: Mapping[str, Any]) -> 'ExampleMetadata':
    """Returns an ExampleMetadata restored from an unstructured representation.

    Keys missing from `unstructured_metadata` fall back to the field defaults.
    Container fields are copied, so the returned object never shares mutable
    state with `unstructured_metadata`; an explicit None for a container field
    is treated as empty.

    Args:
      unstructured_metadata: The contents of an ExampleMetadata in the form
        output by ExampleMetadata.serialize.
    """
    defaults = ExampleMetadata()

    def lookup(key: str) -> Any:
      """Returns the serialized value for `key`, or the field's default."""
      return unstructured_metadata.get(key, getattr(defaults, key))

    production = lookup('production')
    if production is not None:
      # Serialized as a JSON-encoded string (see _example_set_value_serializer).
      production = nltk_utils.production_from_production_string(
          json.loads(production))

    production_provenance = lookup('production_provenance')
    if production_provenance is not None:
      production_provenance = (
          production_composition.ProductionProvenance.from_json(
              production_provenance))

    train_similarity = lookup('train_similarity')
    if train_similarity:
      train_similarity = ExampleTrainSimilarityMetadata(**train_similarity)
    else:
      train_similarity = None

    distractor_rules_by_unreliable_rule = {
        unreliable_rule: list(distractor_rules)
        for unreliable_rule, distractor_rules in (
            lookup('distractor_rules_by_unreliable_rule') or {}).items()
    }

    return ExampleMetadata(
        rules=set(lookup('rules') or ()),
        target_rule=lookup('target_rule'),
        derivation_level=lookup('derivation_level'),
        original_reply=lookup('original_reply'),
        num_variables=lookup('num_variables'),
        applied_edits=list(lookup('applied_edits') or ()),
        new_source_production_by_source_production=dict(
            lookup('new_source_production_by_source_production') or {}),
        original_request=lookup('original_request'),
        as_rule=lookup('as_rule'),
        distractor_rules_by_unreliable_rule=distractor_rules_by_unreliable_rule,
        production=production,
        production_provenance=production_provenance,
        train_similarity=train_similarity,
        input_length_standard=lookup('input_length_standard'),
        output_length_standard=lookup('output_length_standard'),
        input_length_compact=lookup('input_length_compact'),
        output_length_compact=lookup('output_length_compact'),
    )


@attr.s(auto_attribs=True, frozen=True, repr=False, cache_hash=True)
class Example:
  """Single conceptual learning example.

  Corresponds to a task in which the (context + request) is to be translated
  into the (reply + qualifier).

  Note that this class is designed to be immutable so as to ensure a stable
  hash, which allows it to be used as a key in a dict or as an element in a set.
  The metadata field is excluded from equality and hashing (eq=False), so two
  examples that differ only in their metadata compare equal.

  Attributes:
    context: A set of examples (either rules or other arbitrary examples)
      representing background knowledge based on which the translation task is
      to be performed. Essentially a what-if scenario.
    request: The request that is to be translated.
    reply: The reply produced in response to the request.
    qualifier: Indication of whether the reply holds monotonically with respect
      to the context or only defeasibly. For a top-level example, this can also
      be interpreted as an indication of whether the reply could be determined
      deductively from the input or whether induction was required.
    metadata: Additional information about the example that is intended to be
      hidden from the learner.
    is_unreliable: (Read-only property.) Whether the example was generated using
      at least one unreliable rule, i.e., whether its metadata records any
      distractor rules.
  """

  # Note: The lambda below is required in order to prevent an error due to the
  # forward reference of FrozenExampleSet.

  context: Optional['FrozenExampleSet'] = attr.Factory(
      lambda: FrozenExampleSet())
  request: str = ''
  reply: str = ''
  qualifier: Qualifier = Qualifier.M
  metadata: ExampleMetadata = attr.ib(
      default=attr.Factory(ExampleMetadata), eq=False, order=False)

  def __bool__(self) -> bool:
    """Returns True if this instance contains any non-trivial public content."""
    return bool(self.context or self.request or self.reply)

  def get_request_type(self) -> RequestType:
    """Returns a classification of the behavior that the request tests.

    A reply of '0' or '1' always marks a rule request. A '?' reply marks a
    non-rule request when metadata.original_reply is set, and a rule request
    otherwise.
    """
    if ((self.reply in (RuleReply.TRUE, RuleReply.FALSE)) or
        (self.reply == RuleReply.UNKNOWN and not self.metadata.original_reply)):
      # Rule examples' metadata.original_reply is never set even if its reply
      # is UNKNOWN.  In that case we know that it does not follow from the
      # context, but also does not contradict it.
      return RequestType.RULE
    return RequestType.NON_RULE

  def get_example_type(self) -> ExampleType:
    """Returns the ExampleType of the example.

    Raises:
      ValueError: If the combination of request type, reply and qualifier has
        no corresponding ExampleType member.
    """
    request_type = self.get_request_type()
    if request_type == RequestType.NON_RULE:
      request_type_string = 'nonrule'
    else:
      request_type_string = 'rule'

    if self.reply == RuleReply.UNKNOWN:
      known_string = RuleReply.UNKNOWN
    else:
      known_string = ''

    # The reply slot is only filled for rule examples with a known reply.
    if request_type == RequestType.NON_RULE:
      reply_string = ''
    elif self.reply == RuleReply.UNKNOWN:
      reply_string = ''
    else:
      reply_string = self.reply

    example_type_value = (
        f'{request_type_string}_{known_string}_{reply_string}_{self.qualifier}')

    return ExampleType(example_type_value)

  def get_knownness(self) -> Knownness:
    """Returns the knownness of the example.

    Raises:
      ValueError: If the reply is known but the qualifier is neither
        Qualifier.D nor Qualifier.M.
    """
    if self.reply == RuleReply.UNKNOWN:
      return Knownness.UNKNOWN
    if self.qualifier == Qualifier.D:
      return Knownness.KNOWN_DEFEASIBLE
    if self.qualifier == Qualifier.M:
      return Knownness.KNOWN_MONOTONIC
    # Fail with a descriptive error rather than leaving the result unset.
    raise ValueError(f'Unsupported qualifier: {self.qualifier!r}')

  def get_triviality(self, context: Iterable['Example']) -> Triviality:
    """Returns the triviality of the example with respect to the context.

    The checks are applied in the order IDENTICAL,
    REPHRASE_CONTEXT_RULE_AS_NONRULE, REPHRASE_CONTEXT_NONRULE_AS_RULE,
    NEGATION; the first match is returned, otherwise NON_TRIVIAL.

    Args:
      context: The examples to compare against. It may be iterated once per
        check, so it should be a re-iterable collection (such as an ExampleSet)
        rather than a one-shot iterator. The NEGATION check reads
        metadata.production of this example and of each context example.
    """
    request_type = self.get_request_type()

    def any_equal_request_and_reply(context):
      return any(((self.request == context_example.request) and
                  (self.reply == context_example.reply))
                 for context_example in context)

    def any_context_request_equal_as_rule(context):
      return any(context_example.request == self.metadata.as_rule
                 for context_example in context)

    def any_context_as_rule_equal_request(context):
      return any(context_example.metadata.as_rule == self.request
                 for context_example in context)

    def any_contradicting_context_example(context):
      # This is the same logic as how inconsistencies are detected in inference
      # engine, but only for examples that actually appear in the context.  We
      # implement the logic here instead of going through the inference engine
      # to keep this function more self-contained.
      # The current implementation is inefficient and rebuilds the input/output
      # tokens mapping each time.  With the more expensive inner functions
      # cached, this does not seem to be a bottleneck for dataset generation
      # speed.
      input_tokens = nltk_utils.extract_rhs_tokens(self.metadata.production)
      output_tokens = nltk_utils.extract_lhs_tokens(self.metadata.production)
      for context_example in context:
        known_input_tokens = nltk_utils.extract_rhs_tokens(
            context_example.metadata.production)
        known_output_tokens = nltk_utils.extract_lhs_tokens(
            context_example.metadata.production)
        if (input_tokens == known_input_tokens and
            output_tokens != known_output_tokens):
          return True
      return False

    if any_equal_request_and_reply(context):
      return Triviality.IDENTICAL
    if (request_type == RequestType.NON_RULE and
        any_context_request_equal_as_rule(context)):
      return Triviality.REPHRASE_CONTEXT_RULE_AS_NONRULE
    if request_type == RequestType.RULE:
      if any_context_as_rule_equal_request(context):
        return Triviality.REPHRASE_CONTEXT_NONRULE_AS_RULE
      if any_contradicting_context_example(context):
        return Triviality.NEGATION
    return Triviality.NON_TRIVIAL

  @property
  def is_unreliable(self) -> bool:
    return bool(self.metadata.distractor_rules_by_unreliable_rule)

  def to_string(self, prefix: str = '', include_metadata: bool = False) -> str:
    """Returns a string representation of the Example.

    If the context is empty, the Example is displayed in one line; otherwise,
    each Example of the context is output on a separate line with indentation.
    (The exact layout of the context part is produced by the context's own
    to_string.)

    Example (empty context):
      <{}, Q, R, D>
    Example (nonempty context):
      <{<{}, Q1, R1, M>,
        <{}, Q2, R2, M>,}, Q, R, D>
    Example (empty context, with metadata):
      <{}, Q, R, D, metadata=ExampleMetadata(rules={'A=B'})>

    Args:
      prefix: Optional prefix string (typically indicating some kind of
        additional indentation) to output before the standard default
        indentation for each nested example in the case of a non-empty context.
        Used for achieving nested indentation in multi-level contexts.
      include_metadata: Whether to include metadata. By default, excludes
        metadata, so as to output only the information that would be visible to
        a learner when observing this Example in a training set. Even when True,
        metadata is only shown if it is truthy.
    """
    parts = [
        self.context.to_string(prefix, include_metadata), self.request,
        self.reply, self.qualifier
    ]
    if include_metadata and self.metadata:
      parts.append(f'metadata={self.metadata!r}')
    return '<' + ', '.join(parts) + '>'

  def __str__(self):
    return self.to_string()

  def __repr__(self):
    return self.to_string(include_metadata=True)

  def serialize(self) -> Dict[str, Any]:
    """Returns a representation of this object using JSON-serializable types."""
    return attr.asdict(self, value_serializer=_example_set_value_serializer)

  @classmethod
  def deserialize(cls, unstructured_example: Mapping[str, Any]) -> 'Example':
    """Returns an Example restored from an unstructured representation.

    Args:
      unstructured_example: The contents of an Example in the form output by
        Example.serialize. A missing or None 'metadata' entry is treated as
        empty metadata.

    Raises:
      ValueError: If the serialized qualifier is not a valid Qualifier value.
    """
    defaults = Example()
    unstructured_example_metadata = unstructured_example.get('metadata') or {}
    return Example(
        context=FrozenExampleSet.from_example_set(
            ExampleSet.deserialize(unstructured_example.get('context', ''))),
        request=unstructured_example.get('request', defaults.request),
        reply=unstructured_example.get('reply', defaults.reply),
        # JSON round-trips turn the enum into a plain string; restore it.
        qualifier=Qualifier(
            unstructured_example.get('qualifier', defaults.qualifier)),
        metadata=ExampleMetadata.deserialize(unstructured_example_metadata))

  def to_simple_example(self) -> 'Example':
    """Returns an Example that is the same as self but with empty context."""
    return attr.evolve(self, context=FrozenExampleSet())

  def get_md5_hash(self) -> str:
    """Returns an md5 hash of the example's identifying content.

    Metadata is excluded from the hashed string.
    """
    example_string = self.to_string(prefix=' ', include_metadata=False)
    return hashlib.md5(example_string.encode('utf-8')).hexdigest()

  def get_output_pattern(self):
    """Returns the pattern of the example's output sequence."""
    return nltk_utils.output_pattern_from_production(self.metadata.production)


@attr.s(auto_attribs=True)
class ExampleSetTrainSimilarityMetadata:
  """Metadata about a context's similarity to contexts from the train set.

  Only relevant for contexts from the validation or test set. Both fields
  default to 0.

  Attributes:
    nearest_similarity_by_rule_overlap: Similarity of the nearest neighbor to
      this context from the train set, measured in terms of rule overlap. The
      rules considered here include hidden, unreliable and omitted rules in
      addition to explicitly asserted rules, but do not include derived rules.
    nearest_similarity_by_example_overlap: Similarity of the nearest neighbor to
      this context from the train set, measured in terms of example overlap.
  """
  nearest_similarity_by_rule_overlap: float = 0
  nearest_similarity_by_example_overlap: float = 0


@attr.s(auto_attribs=True)
class ExampleSetMetadata:
  """Metadata about a context ExampleSet (hidden from learner).

  We store here any information that we may want to track about a context,
  but which the learner is not intended to have access to. This information may
  be used, for example, for more efficient lookups, or for constructing
  principled train-test splits, or in generation of diagnostic metrics, etc.

  Note that many of the metadata fields (e.g., the ones tracking rules and
  productions) are intended specifically for ExampleSets that are used as
  contexts, and some are only meaningful in that case. For top-level
  ExampleSets, these metadata fields may not be populated.

  To guarantee consistency between the contents of ExampleSetMetadata and
  the ExampleSet itself, users should not modify the contents of
  ExampleSetMetadata directly, but rather add examples and rules strictly via
  the relevant methods in ExampleSet.

  Attributes:
    rule_format: The format used for representing rules in string format in the
      rule requests and rule-related metadata.
    rules: The full list of rules describing the current task. This does not
      include the distractor rules.
    grammar: Grammar used to generate the context.
    example_indices: Mapping of each example to its index in the ExampleSet.
    explicit_rules: List of all the explicit rules.
    hidden_rules: List of all the hidden rules.
    unreliable_rules: List of all the unreliable rules.
    omitted_rules: List of all the omitted rules.
    hidden_true_rules: Hidden rules that satisfied the inductive bias and are
      therefore inducible to be true. Populated only after context generation is
      complete.
    hidden_unknown_rules: Hidden rules that did not satisfy the inductive bias
      and are therefore unknown. Populated only after context generation is
      complete.
    distractor_rules_by_unreliable_rule: Mapping from each unreliable rule to
      the list of distractor rules that were actually used in its place in the
      examples.
    distractor_rules: Read-only view on all the distractor rules (created when
      unreliable rules are used while generating examples).
    rule_reply_by_hidden_rule: Mapping from each hidden rule to its truth value
      (TRUE if it satisfies the inductive bias, or UNKNOWN otherwise). Populated
      only after context generation is complete, since the inductive bias may
      behave non-monotonically on a growing context.
    examples_by_rule: List of examples that depend on each rule (including
      distractor rules).
    examples_by_example_type: List of examples of each example type.
    variable_substitutions_by_rule: Mapping from each rule to a mapping of its
      variable substitutions. The variable substitution mapping itself is a
      mapping from each of the rule's variable names (e.g., 'x1') to a set of
      input phrases that the variable was substituted for in any of the examples
      in the context. E.g., for the rule '[x1 twice] = …', if it is used in an
      example <'turn left twice after jump', …>, then the variable substitution
      for x1 would be 'turn left'. If the example were <'x1 twice after jump =
      …', …>, then the variable substitution for x1 would be simply 'x1'.
    outer_substitutions_by_rule: Mapping from each rule to a set of its "outer
      substitutions". We define the "outer substitution" to be the string
      corresponding to the node immediately above in the rule application tree
      (skipping over pass-through rule nodes), with the substring corresponding
      to the current rule replaced with '__'. If the given rule is the topmost
      non-pass-through rule in the rule application tree, then we define the
      outer substitution to be simply '__'. E.g., in the example 'turn left
      twice after jump', the outer substitution for '[x1 twice] = …' would be
      '__ after jump', while the outer substitution for '[turn] = …'
      would be '__ left', and the outer substitution for '[x1 after x2] = …'
      would be '__'.
    reliable_variable_substitutions_by_rule: Same as
      variable_substitutions_by_rule, but only counting substitutions observed
      in reliable examples.
    reliable_outer_substitutions_by_rule: Same as outer_substitutions_by_rule,
      but only counting substitutions observed in reliable examples.
    train_similarity: Metadata about a context's similarity to contexts from the
      train set. Only populated in the validation and test splits.
  """

  rule_format: Optional[enums.RuleFormat] = None
  rules: List[str] = attr.Factory(list)
  # The reason we don't use grammar for equality is that FeatureGrammar objects
  # return False for equality even though they have the same content. The only
  # case when two Feature grammar objects return True on equality is when they
  # are the same object.
  grammar: Optional[nltk.grammar.FeatureGrammar] = attr.field(
      default=None, eq=False)
  example_indices: Dict[Example, int] = attr.Factory(dict)
  explicit_rules: List[str] = attr.Factory(list)
  hidden_rules: List[str] = attr.Factory(list)
  unreliable_rules: List[str] = attr.Factory(list)
  omitted_rules: List[str] = attr.Factory(list)
  distractor_rules_by_unreliable_rule: Dict[str, List[str]] = attr.Factory(dict)

  rule_reply_by_hidden_rule: Dict[str, str] = attr.Factory(dict)
  examples_by_rule: Dict[str, List[Example]] = attr.Factory(
      lambda: collections.defaultdict(list))
  examples_by_example_type: Dict[ExampleType, List[Example]] = attr.Factory(
      lambda: collections.defaultdict(list))

  variable_substitutions_by_rule: (
      production_trees.VariableSubstitutionsByRule) = (
          attr.Factory(dict))
  outer_substitutions_by_rule: production_trees.OuterSubstitutionsByRule = (
      attr.Factory(dict))
  reliable_variable_substitutions_by_rule: (
      production_trees.VariableSubstitutionsByRule) = (
          attr.Factory(dict))
  reliable_outer_substitutions_by_rule: (
      production_trees.OuterSubstitutionsByRule) = (
          attr.Factory(dict))

  train_similarity: Optional[ExampleSetTrainSimilarityMetadata] = None

  @property
  def distractor_rules(self):
    """Rules used by examples (keys of examples_by_rule) but not contained here.

    These are the distractor variants of unreliable rules; see contains_rule.
    """
    return [
        rule for rule in self.examples_by_rule.keys()
        if not self.contains_rule(rule)
    ]

  @property
  def hidden_true_rules(self):
    """Hidden rules whose recorded reply is RuleReply.TRUE."""
    return [
        rule for rule in self.hidden_rules
        if self.rule_reply_by_hidden_rule.get(rule, None) == RuleReply.TRUE
    ]

  @property
  def hidden_unknown_rules(self):
    """Hidden rules whose recorded reply is RuleReply.UNKNOWN."""
    return [
        rule for rule in self.hidden_rules
        if self.rule_reply_by_hidden_rule.get(rule, None) == RuleReply.UNKNOWN
    ]

  def contains_rule(self, rule):
    """Returns whether the given rule is contained in this context.

    In the case of an unreliable rule, the context is considered to contain
    the original version of the unreliable rule, but not any of the distractor
    variants.

    Args:
      rule: The rule to check whether the ExampleSet contains it.
    """
    return (rule in self.omitted_rules or rule in self.explicit_rules or
            rule in self.unreliable_rules or rule in self.hidden_rules)

  def __deepcopy__(self, memo):
    """Copies the mutable containers, but not the immutable Examples in them.

    This is important, as otherwise the potentially arbitrarily deep nesting of
    ExampleSets in each Example's context in conjunction with the multiple
    references to each Example in the ExampleSetMetadata could lead to an order
    of magnitude or more Example instances when an ExampleSet is deepcopied to
    create a FrozenExampleSet.

    Args:
      memo: Memoization dict (unused). See copy.deepcopy.

    Returns:
      A copy of this instance with newly copied containers, but which still
      references the original Examples. Every attribute is carried over,
      including train_similarity.
    """
    result = ExampleSetMetadata(
        # For primitive values, we can use simple assignment.
        rule_format=self.rule_format,
        # For containers of strings, it doesn't matter whether we use deepcopy
        # or ordinary copy.
        rules=self.rules.copy(),
        explicit_rules=self.explicit_rules.copy(),
        hidden_rules=self.hidden_rules.copy(),
        omitted_rules=self.omitted_rules.copy(),
        unreliable_rules=self.unreliable_rules.copy(),
        rule_reply_by_hidden_rule=self.rule_reply_by_hidden_rule.copy(),
        # Here we need to use ordinary copy to avoid copying the Examples.
        example_indices=self.example_indices.copy(),
        # Here we need to use deepcopy because of the nested containers.
        distractor_rules_by_unreliable_rule=copy.deepcopy(
            self.distractor_rules_by_unreliable_rule),
        variable_substitutions_by_rule=copy.deepcopy(
            self.variable_substitutions_by_rule),
        outer_substitutions_by_rule=copy.deepcopy(
            self.outer_substitutions_by_rule),
        reliable_variable_substitutions_by_rule=copy.deepcopy(
            self.reliable_variable_substitutions_by_rule),
        reliable_outer_substitutions_by_rule=copy.deepcopy(
            self.reliable_outer_substitutions_by_rule),
        grammar=copy.deepcopy(self.grammar),
        # Previously omitted, which silently dropped the train-similarity
        # information whenever a context was frozen (including on every
        # deserialization round trip via FrozenExampleSet.from_example_set).
        train_similarity=copy.deepcopy(self.train_similarity))
    # We need to copy the below attributes manually because they involve both
    # Examples (which we don't want to deepcopy) and nested containers (which we
    # do need to deepcopy).
    for key, values in self.examples_by_rule.items():
      result.examples_by_rule[key] = values[:]
    for key, values in self.examples_by_example_type.items():
      result.examples_by_example_type[key] = values[:]

    return result

  def _get_num_examples_with_unreliable_rule_or_distractor(
      self, unreliable_rule):
    """Counts examples using the unreliable rule or any of its distractors."""
    result = len(self.examples_by_rule.get(unreliable_rule, ()))
    result += sum(
        len(self.examples_by_rule.get(distractor, ()))
        for distractor in self.distractor_rules_by_unreliable_rule.get(
            unreliable_rule, ()))
    return result

  def serialize(self):
    """Returns the desired representation of the metadata when serialized.

    This method transforms the information tracked in the metadata used during
    dataset generation into a format (e.g. counters) more suitable for the
    serialized dataset. The returned structure does not share any mutable
    containers with this instance.
    """

    if self.grammar is not None:
      grammar_string = nltk_utils.grammar_to_string(self.grammar)
    else:
      grammar_string = ''
    metadata_dict = {
        'rule_format': self.rule_format,
        'rules': self.rules.copy(),
        'grammar': grammar_string,
        'omitted_rules': {
            rule: len(self.examples_by_rule.get(rule, ()))
            for rule in self.omitted_rules
        },
        'explicit_rules': {
            rule: len(self.examples_by_rule.get(rule, ()))
            for rule in self.explicit_rules
        },
        'hidden_rules': {
            rule: len(self.examples_by_rule.get(rule, ()))
            for rule in self.hidden_rules
        },
        'unreliable_rules': {
            rule:
            self._get_num_examples_with_unreliable_rule_or_distractor(rule)
            for rule in self.unreliable_rules
        },
        'distractor_rules': {
            rule: len(self.examples_by_rule.get(rule, ()))
            for rule in self.distractor_rules
        },
        # Copy the inner lists too, so that the serialized output does not
        # alias the live lists that ExampleSet.add_example keeps extending.
        'distractor_rules_by_unreliable_rule': {
            rule: list(distractors) for rule, distractors in
            self.distractor_rules_by_unreliable_rule.items()
        },
        'rule_reply_by_hidden_rule': self.rule_reply_by_hidden_rule.copy(),
        'num_by_example_type': {
            example_type: len(examples_of_example_type) for example_type,
            examples_of_example_type in self.examples_by_example_type.items()
        },
        'min_num_variable_substitutions_by_rule': {
            rule: production_trees.get_effective_min_num_variable_substitutions(
                mapping)
            for rule, mapping in self.variable_substitutions_by_rule.items()
        },
        'num_outer_substitutions_by_rule': {
            rule: production_trees.get_effective_num_outer_substitutions(
                substitutions)
            for rule, substitutions in self.outer_substitutions_by_rule.items()
        },
        'min_num_reliable_variable_substitutions_by_rule': {
            rule: production_trees.get_effective_min_num_variable_substitutions(
                mapping) for rule, mapping in
            self.reliable_variable_substitutions_by_rule.items()
        },
        'num_reliable_outer_substitutions_by_rule': {
            rule: production_trees.get_effective_num_outer_substitutions(
                substitutions) for rule, substitutions in
            self.reliable_outer_substitutions_by_rule.items()
        },
    }
    if self.train_similarity:
      metadata_dict['train_similarity'] = attr.asdict(self.train_similarity)
    return metadata_dict

  def deserialize(self, unstructured_metadata):
    """Recovers the contents of the given serialized ExampleSetMetadata.

    Does nothing if `unstructured_metadata` is empty or None.

    Args:
      unstructured_metadata: The contents of an ExampleSetMetadata in the form
        output by ExampleSetMetadata.serialize.
    """
    if not unstructured_metadata:
      return
    self.rule_format = unstructured_metadata.get('rule_format', None)
    self.rules = list(unstructured_metadata.get('rules', ()))
    grammar_string = unstructured_metadata.get('grammar', '')
    if grammar_string:
      self.grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    # The serialized rule fields map rule -> example count; iterating them
    # yields just the rules.
    self.explicit_rules = list(unstructured_metadata.get('explicit_rules', ()))
    self.hidden_rules = list(unstructured_metadata.get('hidden_rules', ()))
    self.omitted_rules = list(unstructured_metadata.get('omitted_rules', ()))
    self.unreliable_rules = list(
        unstructured_metadata.get('unreliable_rules', ()))
    # Copy the inner lists so that later in-place extension of this mapping
    # (when examples are added) does not mutate the caller's input.
    distractors_by_rule = dict(
        unstructured_metadata.get('distractor_rules_by_unreliable_rule', {}))
    self.distractor_rules_by_unreliable_rule = {
        rule: list(distractors)
        for rule, distractors in distractors_by_rule.items()
    }
    self.rule_reply_by_hidden_rule = dict(
        unstructured_metadata.get('rule_reply_by_hidden_rule', {}))
    train_similarity = unstructured_metadata.get('train_similarity', None)
    if train_similarity:
      self.train_similarity = ExampleSetTrainSimilarityMetadata(
          **train_similarity)
    # Note that the following fields do not need to be restored here, as they
    # are restored automatically when the examples themselves are restored:
    # num_by_example_type, [reliable_]variable_substitutions_by_rule,
    # [reliable_]outer_substitutions_by_rule.


def _postprocess_unstructured_example_set_value(value):
  """Performs recursive post-processing on an arbitrary value in an ExampleSet.

  Args:
    value: An arbitrary value within the structure returned by attr.asdict when
      called on an AbstractExampleContainer.

  Returns:
    The same value in a format that matches the actual desired unstructured
    representation of an ExampleSet.
  """
  if isinstance(value, dict):
    return _postprocess_unstructured_example_set_dict(value)
  elif isinstance(value, list):
    return _postprocess_unstructured_example_set_list(value)
  else:
    return value


def _postprocess_unstructured_example_set_list(
    examples_list):
  """Performs recursive post-processing on a list in an ExampleSet.

  Args:
    examples_list: An arbitrary list within the structure returned by
      attr.asdict when called on an AbstractExampleContainer. Normally this will
      be a list of examples.

  Returns:
    The same list in a format that matches the actual desired unstructured
    representation of an ExampleSet.
  """
  postprocessed_list = []
  for v in examples_list:
    postprocessed_list.append(_postprocess_unstructured_example_set_value(v))
  return postprocessed_list


def _postprocess_unstructured_example_set_dict(
    example_dict):
  """Performs recursive post-processing on a dict in an ExampleSet.

  Args:
    example_dict: An arbitrary dict within the structure returned by attr.asdict
      when called on an AbstractExampleContainer. Normally this will be a dict
      representing an unstructured Example.

  Returns:
    The same dict in a format that matches the actual desired unstructured
    representation of an ExampleSet.
  """
  postprocessed_dict = {}
  for k, v in example_dict.items():
    if k == '_requests':
      # Don't serialize this field, as it is just an additional index on
      # existing information and can be reconstructed at deserialization time
      # based on the other fields.
      continue
    elif ((k == 'context' and v is not None and not v['_examples']) or
          (k == 'metadata' and v is not None and not any(v.values()))):
      # Omit empty fields like 'metadata=ExampleSetMetadata()' for conciseness.
      v = None
    postprocessed_dict[k] = _postprocess_unstructured_example_set_value(v)
  return postprocessed_dict


class AbstractExampleContainer(metaclass=abc.ABCMeta):
  """Shared functionality of all classes that store examples."""

  def _get_asdict_filter(self):
    """Returns the attribute filter that serialize passes to attr.asdict.

    This default keeps every attribute; subclasses override it to restrict
    which attributes get serialized.
    """
    return lambda attribute, value: True

  def serialize(self):
    """Returns a representation of this object using JSON-serializable types."""
    # pytype: disable=wrong-keyword-args
    raw_dict = attr.asdict(
        self,
        filter=self._get_asdict_filter(),
        value_serializer=_example_set_value_serializer)
    # pytype: enable=wrong-keyword-args
    return _postprocess_unstructured_example_set_dict(raw_dict)

  @abc.abstractmethod
  def to_string(self, prefix=' ', include_metadata=False):
    """Returns a string representation of the example set.

    Args:
      prefix: String (typically indentation) emitted at the start of every line
        after the first; used to indent nested, multi-level contexts.
      include_metadata: Whether to include metadata. Defaults to False, so that
        only what a learner would see when observing this ExampleSet as a
        training set is output.
    """

  def __str__(self):
    return self.to_string()

  def get_md5_hash(self):
    """Returns the hex md5 digest of str(self), encoded as UTF-8."""
    digest = hashlib.md5(str(self).encode('utf-8'))
    return digest.hexdigest()


class AbstractExampleSequence(AbstractExampleContainer):
  """Base class for example containers that expose a sequence interface.

  Subclasses only implement `_get_examples`; length, indexing, membership and
  iteration are all delegated to the sequence it returns. Example containers
  that cannot perform length and random access operations in constant time
  should not inherit this class.
  """

  @abc.abstractmethod
  def _get_examples(self):
    """Returns the underlying sequence of examples."""

  def __iter__(self):
    examples = self._get_examples()
    return iter(examples)

  def __contains__(self, example):
    examples = self._get_examples()
    return example in examples

  def __getitem__(self, key):
    examples = self._get_examples()
    return examples[key]

  def __len__(self):
    examples = self._get_examples()
    return len(examples)


@attr.s(auto_attribs=True)
class AbstractExampleSequenceWithMetadata(AbstractExampleSequence):
  """Contains shared functionality of example sets with metadata.

  Includes functionalities that depend on the presence of the metadata
  attribute.

  Attributes:
    metadata: Additional information about the example set that is intended to
      be hidden from the learner.
    explicit_fraction: The fraction of the rules in metadata.rules that are
      explicit rules, or 0.0 if metadata.rules is empty.
  """
  metadata: ExampleSetMetadata = attr.ib(
      default=attr.Factory(ExampleSetMetadata), eq=False, order=False)

  def __contains__(self, example):
    # With metadata we have a more efficient implementation of __contains__.
    return example in self.metadata.example_indices

  def to_string(self, prefix = ' ', include_metadata = False):
    """Returns a string representation of the example set.

    Each Example is output on a separate line. If an Example has a non-empty
    context, then each Example in the context is again output on a separate
    line with nested indentation.

    Example:
      {<{}, a, b, M>
       <{<{}, c, d, M>
       <{}, e, f, D>}, g, h, M>
       <{}, i, j, D>}

    Args:
      prefix: Prefix string (typically indicating some kind of indentation) to
        output at the beginning of each line after the first one. Used for
        achieving nested indentation in multi-level contexts.
      include_metadata: Whether to include metadata. By default, excludes
        metadata, so as to output only the information that would be visible to
        a learner when observing this ExampleSet as a training set.
    """
    example_strings = [
        example.to_string(prefix + '  ', include_metadata=include_metadata)
        for example in self
    ]
    # Metadata is only appended for non-empty example sets.
    if include_metadata and self:
      example_strings.append(f'metadata={self.metadata!r}')
    delimiter = '\n' + prefix
    return '{' + delimiter.join(example_strings) + '}'

  @property
  def explicit_fraction(self):
    num_rules = len(self.metadata.rules)
    if not num_rules:
      # Rule metadata may be unpopulated (e.g., for top-level ExampleSets);
      # avoid a ZeroDivisionError in that case.
      return 0.0
    return len(self.metadata.explicit_rules) / num_rules


@attr.s(auto_attribs=True, frozen=True, repr=False, cache_hash=True)
class FrozenExampleSet(AbstractExampleSequenceWithMetadata):
  """Immutable version of ExampleSet, for use as the context of an Example.

  Immutable means that it is also hashable, which allows it to be used as a key
  in a dict or as an element in a set.

  Typically created via FrozenExampleSet.from_example_set().
  """

  _examples: Tuple[Example, Ellipsis] = attr.Factory(tuple)
  # Set of the requests of all examples, used by request_already_in_example_set.
  # Derived from _examples in __attrs_post_init__, hence not an init argument.
  _requests: AbstractSet[str] = attr.ib(
      init=False, default=attr.Factory(frozenset))

  def __attrs_post_init__(self):
    # With attrs frozen classes, we need to use this trick to set attributes
    # after initialization.
    object.__setattr__(self, '_requests',
                       frozenset(example.request for example in self))

  def _get_asdict_filter(self):
    # Only the examples and metadata are serialized; _requests is derived.
    allowed_keys = ['_examples', 'metadata']

    def asdict_filter(attribute, value):
      del value
      return attribute.name in allowed_keys

    return asdict_filter

  def _get_examples(self):
    return self._examples

  def __repr__(self):
    return self.to_string(include_metadata=True)

  def request_already_in_example_set(self, request):
    """Returns whether some example in this set has the given request."""
    return request in self._requests

  @property
  def is_unreliable(self):
    """Whether the metadata records any unreliable rules."""
    return bool(self.metadata.unreliable_rules)

  def illustrates_rule(self, rule):
    """Returns whether the rule is an explicit or hidden rule of this set."""
    return (rule in self.metadata.explicit_rules or
            rule in self.metadata.hidden_rules)

  @classmethod
  def from_example_set(cls, dataset):
    """Creates a frozen version of the given ExampleSet.

    The (immutable) Examples are shared with `dataset`, while its metadata is
    deep-copied so that later changes to `dataset` do not leak into the result.
    """
    return cls(
        examples=tuple(dataset), metadata=copy.deepcopy(dataset.metadata))

  @classmethod
  def from_examples(cls, examples):
    """Creates a FrozenExampleSet containing the given examples."""
    return cls.from_example_set(ExampleSet.from_examples(examples))


@attr.s(auto_attribs=True, repr=False)
class ExampleSet(AbstractExampleSequenceWithMetadata):
  """Mutable container of conceptual learning examples.

  An ExampleSet can either be used directly to represent a conceptual learning
  dataset as a whole, or be converted to a FrozenExampleSet for use as the
  context of an Example.

  Rules and examples should be added only through the methods of this class,
  so that the bookkeeping in `metadata` stays consistent with the examples.
  """

  _examples: List[Example] = attr.Factory(list)

  def _get_examples(self):
    return self._examples

  def __repr__(self):
    return self.to_string(include_metadata=True)

  def add_omitted_rule(self, rule):
    """Registers the given rule as an omitted rule (no examples are added).

    Args:
      rule: The rule to register.

    Raises:
      ValueError: If the rule is already contained in this ExampleSet.
    """
    self._add_rule(rule)
    self.metadata.omitted_rules.append(rule)

  def add_explicit_rule(self, rule, example):
    """Adds to the ExampleSet an example explicitly asserting the given rule.

    Args:
      rule: The rule to be explicitly asserted.
      example: Example asserting the rule.

    Raises:
      ValueError: If the rule is already contained in this ExampleSet.
    """
    self._add_rule(rule)
    self.metadata.explicit_rules.append(rule)
    self.add_example(example)

  def add_hidden_rule(self, rule,
                      illustrative_examples):
    """Adds to the ExampleSet examples that illustrate the given hidden rule.

    Should be called exactly once for each hidden rule of the original grammar.

    Args:
      rule: Rule to add.
      illustrative_examples: Examples that use that rule.

    Raises:
      ValueError: If the rule is already contained in this ExampleSet.
    """
    self._add_rule(rule)
    self.metadata.hidden_rules.append(rule)
    for example in illustrative_examples:
      self.add_example(example)

  def mark_rule_as_unreliable(self, rule):
    """Registers the given rule as unreliable, without adding examples.

    Must be called before add_unreliable_rule is called for the same rule.

    Args:
      rule: The rule to register.

    Raises:
      ValueError: If the rule is already contained in this ExampleSet.
    """
    self._add_rule(rule)
    self.metadata.unreliable_rules.append(rule)

  def add_unreliable_rule(self, rule,
                          illustrative_examples):
    """Adds examples illustrating a rule previously marked as unreliable.

    Args:
      rule: A rule already registered via mark_rule_as_unreliable.
      illustrative_examples: Examples to add for that rule.

    Raises:
      ValueError: If the rule was not marked as unreliable beforehand.
    """
    if rule not in self.metadata.unreliable_rules:
      raise ValueError(
          f'Rule must first be marked as unreliable before being added: {rule}')
    for example in illustrative_examples:
      self.add_example(example)

  def _add_rule(self, rule):
    """Appends the rule to metadata.rules, rejecting duplicates."""
    if self.metadata.contains_rule(rule):
      raise ValueError(f'Rule already present: {rule}')
    self.metadata.rules.append(rule)

  def add_example(self, example):
    """Adds the given example if it is not already present.

    Besides appending the example, this updates the example index, the
    per-example-type and per-rule lists, the distractor mapping and (when a
    rule format is set and the example has production provenance) the
    variable/outer substitution bookkeeping in the metadata.

    Args:
      example: Example to be added.

    Returns:
      The given example if added, or else the existing example if an equivalent
      one was already present.
    """
    if example in self:
      return self[self.metadata.example_indices[example]]

    self._examples.append(example)
    self.metadata.example_indices[example] = len(self) - 1
    self.metadata.examples_by_example_type[example.get_example_type()].append(
        example)

    for rule in example.metadata.rules:
      self.metadata.examples_by_rule[rule].append(example)

    extend_mapping_of_lists_unique(
        self.metadata.distractor_rules_by_unreliable_rule,
        example.metadata.distractor_rules_by_unreliable_rule)

    if example.metadata.production_provenance and self.metadata.rule_format:
      production_tree = (
          production_trees.ProductionTree.from_production_provenance(
              example.metadata.production_provenance))
      production_tree.get_variable_substitutions_by_rule(
          self.metadata.variable_substitutions_by_rule,
          self.metadata.rule_format)
      production_tree.get_outer_substitutions_by_rule(
          self.metadata.outer_substitutions_by_rule, self.metadata.rule_format)

      if not example.is_unreliable:
        production_tree.get_variable_substitutions_by_rule(
            self.metadata.reliable_variable_substitutions_by_rule,
            self.metadata.rule_format)
        production_tree.get_outer_substitutions_by_rule(
            self.metadata.reliable_outer_substitutions_by_rule,
            self.metadata.rule_format)

    return example

  @classmethod
  def from_examples(cls, examples):
    """Creates an ExampleSet containing the given examples."""
    dataset = cls()
    for example in examples:
      dataset.add_example(example)
    return dataset

  @classmethod
  def deserialize(
      cls, unstructured_examples):
    """Returns an ExampleSet restored from an unstructured representation.

    Args:
      unstructured_examples: The contents of an ExampleSet in the form output by
        AbstractExampleContainer.serialize, or None for an empty ExampleSet.
    """
    dataset = cls()
    if unstructured_examples is not None:
      # Metadata is restored first because add_example relies on
      # metadata.rule_format to rebuild the substitution bookkeeping.
      if 'metadata' in unstructured_examples:
        dataset.metadata.deserialize(unstructured_examples['metadata'])
      if '_examples' in unstructured_examples:
        for unstructured_example in unstructured_examples['_examples']:
          dataset.add_example(Example.deserialize(unstructured_example))
    return dataset


@attr.s(auto_attribs=True, repr=False)
class ExampleGroup(AbstractExampleSequence):
  """A mutable example set where all examples share the same context.

  Attributes:
    context: The context shared by every example in the group.
  """

  context: FrozenExampleSet = attr.Factory(FrozenExampleSet)
  _examples: List[Example] = attr.Factory(list)

  def __attrs_post_init__(self):
    # Hash-based mirror of _examples, kept in sync by add_example.
    self._example_set = set(self._examples)

  def __contains__(self, example):
    # Called frequently during dataset generation, so answer from the hash set
    # rather than scanning the list (there is no metadata index here).
    return example in self._example_set

  def _get_examples(self):
    return self._examples

  def to_string(self, prefix=' ', include_metadata=False):
    """Returns a string representation of the example group.

    The group is rendered as a pair of example sets: the shared context,
    followed by the list of top-level examples.

    Example:
      ({<{}, a, b, M>
        <{}, c, d, D>}
       {<{}, a, b, M>
        <{}, g, h, M>
        <{}, i, j, D>})

    Args:
      prefix: String (typically indentation) emitted at the start of every line
        after the first; used to indent nested, multi-level contexts.
      include_metadata: Whether to include metadata. Defaults to False, so that
        only what a learner would see when observing this ExampleSet as a
        training set is output.
    """
    inner_prefix = prefix + ' '
    context_string = self.context.to_string(
        inner_prefix, include_metadata=include_metadata)
    example_strings = [
        example.to_string(inner_prefix + ' ', include_metadata=include_metadata)
        for example in self
    ]
    examples_string = '{' + ('\n' + inner_prefix).join(example_strings) + '}'
    return '(' + context_string + '\n' + prefix + examples_string + ')'

  def __repr__(self):
    return self.to_string(include_metadata=True)

  def to_flat_examples(self):
    """Yields each top-level example with the shared context attached.

    Unlike the grouped form, each yielded Example carries its own context, as
    in an ordinary ExampleSet.
    """
    yield from (
        attr.evolve(example, context=self.context) for example in self)

  def add_example(self, example):
    """Adds the given example unless an equivalent one is already present.

    Args:
      example: Example to be added.

    Returns:
      The given example if it was added, or None if an equivalent one was
      already present.
    """
    is_new = example not in self
    if is_new:
      self._examples.append(example)
      self._example_set.add(example)
      return example
    return None

  def shuffle(self, rng):
    """Shuffles the examples in place using the given random generator."""
    rng.shuffle(self._examples)

  @classmethod
  def deserialize(
      cls, unstructured_example_group):
    """Returns an ExampleGroup restored from an unstructured representation.

    Args:
      unstructured_example_group: The contents of an ExampleGroup in the form
        output by ExampleGroup.serialize.
    """
    context = FrozenExampleSet.from_example_set(
        ExampleSet.deserialize(unstructured_example_group['context']))
    examples = list(
        map(Example.deserialize, unstructured_example_group['_examples']))
    return cls(context=context, examples=examples)


@attr.s(repr=False)
class GroupedExampleSet(AbstractExampleContainer):
  """An example set where examples are grouped by context.

  Each ExampleGroup pairs one context with the examples that use it. A shared
  context is therefore held once per group rather than once per example.

  Attributes:
    example_groups: The example groups making up this set.
  """
  example_groups = attr.ib(type=List[ExampleGroup], default=attr.Factory(list))

  def to_string(self, prefix: str = ' ',
                include_metadata: bool = False) -> str:
    """Returns a string representation of the grouped example set.

    The grouped example set is represented as a set (delimited by curly braces)
    of example groups. Each group after the first starts on a new line.

    Example:
      {({<{}, a, b, M>
         <{}, c, d, D>}
        {<{}, a, b, M>
         <{}, g, h, M>
         <{}, i, j, D>})
       ({<{}, x, y, M>}
        {<{}, u, v, D>})}

    Args:
      prefix: Prefix string (typically indicating some kind of indentation) to
        output at the beginning of each line after the first one. Used for
        achieving nested indentation in multi-level contexts.
      include_metadata: Whether to include metadata. By default, excludes
        metadata, so as to output only the information that would be visible to
        a learner when observing this ExampleSet as a training set.

    Returns:
      The string representation described above.
    """
    example_group_strings = [
        example_group.to_string(
            prefix + ' ', include_metadata=include_metadata)
        for example_group in self.example_groups
    ]
    # Groups after the first each start on a new line indented by `prefix`.
    delimiter = '\n' + prefix
    return '{' + delimiter.join(example_group_strings) + '}'

  def __repr__(self):
    return self.to_string(include_metadata=True)

  def get_contexts(self):
    """Returns an iterator over the contexts associated with the example groups.

    Yields one context per example group, in the order of `example_groups`.
    Contexts are not deduplicated.
    """
    return (example_group.context for example_group in self.example_groups)

  def to_flat_examples(self):
    """Yields the top-level examples from this object in flattened form.

    Flattened form means they contain their own contexts like in an ordinary
    ExampleSet, rather than having the contexts factored out in ExampleGroups.
    """
    yield from itertools.chain.from_iterable(
        example_group.to_flat_examples()
        for example_group in self.example_groups)

  def to_example_set(self):
    """Returns the contents of this object represented as a flat ExampleSet."""
    return ExampleSet.from_examples(self.to_flat_examples())

  @classmethod
  def from_example_set(cls, example_set):
    """Returns a GroupedExampleSet converted from an ExampleSet.

    Examples are grouped by their context. Groups are ordered by the first
    occurrence of each context while iterating `example_set`. Each example is
    converted with `to_simple_example()` before being stored in its group.

    Args:
      example_set: The ExampleSet to be converted.

    Raises:
      ValueError: If any example in a context has nonempty nested context.
    """
    examples_by_context = {}
    for example in example_set:
      context = example.context
      if context not in examples_by_context:
        # The nested-context check depends only on the context, so run it once
        # per distinct context instead of once for every example sharing it.
        if any(nested_example.context for nested_example in context):
          raise ValueError('ExampleSet with nested context cannot be converted '
                           'to the GroupedExampleSet format.')
        examples_by_context[context] = []
      examples_by_context[context].append(example.to_simple_example())

    example_groups = [
        ExampleGroup(context=context, examples=examples)
        for context, examples in examples_by_context.items()
    ]
    return cls(example_groups=example_groups)

  @classmethod
  def deserialize(
      cls, unstructured_grouped_example_set
  ):
    """Returns a GroupedExampleSet restored from unstructured representation.

    Args:
      unstructured_grouped_example_set: The contents of a GroupedExampleSet in
        the form output by GroupedExampleSet.serialize. Its 'example_groups'
        entry is read, and each element is restored with
        ExampleGroup.deserialize.
    """
    example_groups = [
        ExampleGroup.deserialize(unstructured_example_group)
        for unstructured_example_group in
        unstructured_grouped_example_set['example_groups']
    ]
    return cls(example_groups=example_groups)
