# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import textwrap
import traceback
from typing import Optional

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import nltk

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import test_utils


# Sentinel distinguishing "argument omitted" from an explicitly passed value
# (notably `train_similarity=None`, which means "no train similarity").
_USE_DEFAULT = object()


def _get_example_metadata_with_every_serializable_field(
    rule='some rule',
    production=_USE_DEFAULT,
    source_production=_USE_DEFAULT,
    composed_production=_USE_DEFAULT,
    train_similarity=_USE_DEFAULT,
    reliable=False):
  """Returns an ExampleMetadata in which every serializable field is set.

  Default values for the production and train-similarity arguments are built
  fresh on every call. This means metadata objects returned by different calls
  never share sub-objects, so an in-place change made by one test cannot leak
  into another. Passing any explicit value, including `None`, overrides the
  default.

  Args:
    rule: The single rule stored in `rules`.
    production: Value for `production`; defaults to "B[sem=b] -> 'a'".
    source_production: `source` of `production_provenance`; defaults to
      'B[sem=?x1] -> A[sem=?x1]'.
    composed_production: The production stored (paired with 0) as the only
      entry of `production_provenance.compositions`; defaults to
      "A[sem=b] -> 'a'".
    train_similarity: Value for `train_similarity`; defaults to a fresh
      ExampleTrainSimilarityMetadata. Pass `None` to leave it unset.
    reliable: If False, `distractor_rules_by_unreliable_rule` gets a dummy
      entry ({'h': ['i']}); if True it is left empty.

  Returns:
    A cl.ExampleMetadata.
  """
  if production is _USE_DEFAULT:
    production = nltk_utils.production_from_production_string(
        "B[sem=b] -> 'a'")
  if source_production is _USE_DEFAULT:
    source_production = nltk_utils.production_from_production_string(
        'B[sem=?x1] -> A[sem=?x1]')
  if composed_production is _USE_DEFAULT:
    composed_production = nltk_utils.production_from_production_string(
        "A[sem=b] -> 'a'")
  if train_similarity is _USE_DEFAULT:
    train_similarity = cl.ExampleTrainSimilarityMetadata(
        num_train_examples_with_same_request=1,
        nearest_reply_matches=True,
        nearest_similarity=0.5)
  return cl.ExampleMetadata(
      rules={rule},
      target_rule='d',
      derivation_level=1,
      original_reply='e',
      num_variables=2,
      applied_edits=['_swap'],
      new_source_production_by_source_production={},
      original_request='f',
      as_rule='g',
      distractor_rules_by_unreliable_rule={} if reliable else {'h': ['i']},
      production=production,
      production_provenance=production_composition.ProductionProvenance(
          source=source_production, compositions=((composed_production, 0),)),
      train_similarity=train_similarity,
      input_length_standard=5,
      output_length_standard=6,
      input_length_compact=7,
      output_length_compact=8)


class ConceptualLearningConstantsTest(parameterized.TestCase):
  """Checks enum rendering and coverage of EXAMPLE_TYPES_BY_REQUEST_TYPE."""

  # Expected rendering of each constant inside an Example converted to string.
  @parameterized.named_parameters(
      ('M', cl.Qualifier.M, 'M'),
      ('D', cl.Qualifier.D, 'D'),
      ('FALSE', cl.RuleReply.FALSE, '0'),
      ('TRUE', cl.RuleReply.TRUE, '1'),
      ('UNKNOWN', cl.RuleReply.UNKNOWN, '?'),
  )
  def test_enum_to_string(self, enum_value, string_value):
    rendered = str(enum_value)
    self.assertEqual(string_value, rendered)

  # Expected rendering of each constant inside an Example converted to JSON.
  @parameterized.named_parameters(
      ('M', cl.Qualifier.M, '"M"'),
      ('D', cl.Qualifier.D, '"D"'),
      ('FALSE', cl.RuleReply.FALSE, '"0"'),
      ('TRUE', cl.RuleReply.TRUE, '"1"'),
      ('UNKNOWN', cl.RuleReply.UNKNOWN, '"?"'),
  )
  def test_enum_to_json(self, enum_value, string_value):
    encoded = json.dumps(enum_value)
    self.assertEqual(string_value, encoded)

  def test_example_types_by_request_type(self):
    mapping = cl.EXAMPLE_TYPES_BY_REQUEST_TYPE
    with self.subTest('covers_all_request_types'):
      self.assertCountEqual(mapping.keys(), cl.RequestType)

    # Flatten the per-request-type lists into one list of example types.
    flattened = [
        example_type
        for example_types in mapping.values()
        for example_type in example_types
    ]
    with self.subTest('covers_all_example_types'):
      self.assertCountEqual(flattened, cl.ExampleType)


class ExampleTest(parameterized.TestCase):
  """Tests classification, string conversion and serialization of cl.Example."""

  # Note: 'Non-rule unknown reply' expects NON_RULE even though its reply is
  # UNKNOWN, presumably because metadata.original_reply is a non-rule string
  # (compare 'Unknown reply rule example', which has no such metadata).
  @parameterized.named_parameters(
      ('Empty', cl.Example(), cl.RequestType.NON_RULE),
      ('Non-rule', cl.Example(request='a', reply='b'), cl.RequestType.NON_RULE),
      ('Non-rule defeasible',
       cl.Example(request='a', reply='b', qualifier=cl.Qualifier.D),
       cl.RequestType.NON_RULE),
      ('Non-rule unknown reply',
       cl.Example(
           request='a',
           reply=cl.RuleReply.UNKNOWN,
           metadata=cl.ExampleMetadata(original_reply='original reply')),
       cl.RequestType.NON_RULE),
      ('Monotonic proposition with reply=TRUE',
       cl.Example(request='a', reply=cl.RuleReply.TRUE), cl.RequestType.RULE),
      ('Monotonic proposition with reply=FALSE',
       cl.Example(request='a', reply=cl.RuleReply.FALSE), cl.RequestType.RULE),
      ('Defeasible proposition with reply=TRUE',
       cl.Example(
           request='a', reply=cl.RuleReply.TRUE, qualifier=cl.Qualifier.D),
       cl.RequestType.RULE),
      ('Defeasible rule with reply=FALSE',
       cl.Example(
           request='a', reply=cl.RuleReply.FALSE, qualifier=cl.Qualifier.D),
       cl.RequestType.RULE),
      ('Unknown reply rule example',
       cl.Example(
           request='a', reply=cl.RuleReply.UNKNOWN, qualifier=cl.Qualifier.D),
       cl.RequestType.RULE))
  def test_get_request_type(self, example, expected_value):
    self.assertEqual(expected_value, example.get_request_type())

  @parameterized.named_parameters(
      ('Empty', cl.Example(), cl.ExampleType.NONRULE_KNOWN_M),
      ('Non-rule', cl.Example(request='a', reply='b'),
       cl.ExampleType.NONRULE_KNOWN_M),
      ('Non-rule defeasible',
       cl.Example(request='a', reply='b', qualifier=cl.Qualifier.D),
       cl.ExampleType.NONRULE_KNOWN_D),
      ('Monotonic rule with reply=TRUE',
       cl.Example(request='a', reply=cl.RuleReply.TRUE),
       cl.ExampleType.RULE_KNOWN_TRUE_M),
      ('Monotonic rule with reply=FALSE',
       cl.Example(request='a', reply=cl.RuleReply.FALSE),
       cl.ExampleType.RULE_KNOWN_FALSE_M),
      ('Defeasible rule with reply=TRUE',
       cl.Example(
           request='a', reply=cl.RuleReply.TRUE, qualifier=cl.Qualifier.D),
       cl.ExampleType.RULE_KNOWN_TRUE_D),
      ('Defeasible rule with reply=FALSE',
       cl.Example(
           request='a', reply=cl.RuleReply.FALSE, qualifier=cl.Qualifier.D),
       cl.ExampleType.RULE_KNOWN_FALSE_D),
      ('Unknown reply rule example',
       cl.Example(
           request='a', reply=cl.RuleReply.UNKNOWN, qualifier=cl.Qualifier.D),
       cl.ExampleType.RULE_UNKNOWN_D))
  def test_get_example_type(self, example, expected_value):
    self.assertEqual(expected_value, example.get_example_type())

  @parameterized.named_parameters(
      ('Empty', cl.Example(), cl.Knownness.KNOWN_MONOTONIC),
      ('Non-rule', cl.Example(request='a', reply='b'),
       cl.Knownness.KNOWN_MONOTONIC),
      ('Non-rule defeasible',
       cl.Example(request='a', reply='b', qualifier=cl.Qualifier.D),
       cl.Knownness.KNOWN_DEFEASIBLE),
      ('Non-rule unknown reply',
       cl.Example(request='a', reply=cl.RuleReply.UNKNOWN),
       cl.Knownness.UNKNOWN),
      ('Monotonic rule with reply=TRUE',
       cl.Example(request='a', reply=cl.RuleReply.TRUE),
       cl.Knownness.KNOWN_MONOTONIC),
      ('Monotonic rule with reply=FALSE',
       cl.Example(request='a', reply=cl.RuleReply.FALSE),
       cl.Knownness.KNOWN_MONOTONIC),
      ('Defeasible rule with reply=TRUE',
       cl.Example(
           request='a', reply=cl.RuleReply.TRUE, qualifier=cl.Qualifier.D),
       cl.Knownness.KNOWN_DEFEASIBLE),
      ('Defeasible rule with reply=FALSE',
       cl.Example(
           request='a', reply=cl.RuleReply.FALSE, qualifier=cl.Qualifier.D),
       cl.Knownness.KNOWN_DEFEASIBLE),
      ('Unknown reply with original_reply=FALSE',
       cl.Example(
           request='a',
           reply=cl.RuleReply.UNKNOWN,
           qualifier=cl.Qualifier.D,
           metadata=cl.ExampleMetadata(original_reply=cl.RuleReply.FALSE)),
       cl.Knownness.UNKNOWN),
      ('Unknown reply with original_reply=TRUE',
       cl.Example(
           request='a',
           reply=cl.RuleReply.UNKNOWN,
           qualifier=cl.Qualifier.D,
           metadata=cl.ExampleMetadata(original_reply=cl.RuleReply.TRUE)),
       cl.Knownness.UNKNOWN))
  def test_get_knownness(self, example, expected_value):
    self.assertEqual(expected_value, example.get_knownness())

  @parameterized.named_parameters(
      ('identical_nonrule', cl.Example(request='a', reply='b'),
       cl.Triviality.IDENTICAL),
      ('identical_rule', cl.Example(request='c = d', reply=cl.RuleReply.TRUE),
       cl.Triviality.IDENTICAL),
      ('rephrase_context_rule_as_nonrule',
       cl.Example(
           request='c', reply='d',
           metadata=cl.ExampleMetadata(as_rule='c = d')),
       cl.Triviality.REPHRASE_CONTEXT_RULE_AS_NONRULE),
      ('rephrase_context_nonrule_as_rule',
       cl.Example(request='a = b', reply=cl.RuleReply.TRUE),
       cl.Triviality.REPHRASE_CONTEXT_NONRULE_AS_RULE),
      ('negation',
       cl.Example(
           request='c = not_d',
           reply=cl.RuleReply.FALSE,
           metadata=cl.ExampleMetadata(
               production=nltk_utils.production_from_production_string(
                   "A[sem=(WALK+WALK)] -> 'walk'"))), cl.Triviality.NEGATION),
      ('non_trivial_nonrule', cl.Example(request='x', reply='y'),
       cl.Triviality.NON_TRIVIAL),
      ('non_trivial_positive_rule',
       cl.Example(request='x = y', reply=cl.RuleReply.TRUE),
       cl.Triviality.NON_TRIVIAL),
      ('non_trivial_negative_rule',
       cl.Example(request='x = not_y', reply=cl.RuleReply.FALSE),
       cl.Triviality.NON_TRIVIAL))
  def test_get_triviality(self, example, expected):
    # The context holds a non-rule example 'a' -> 'b' (whose rule form is
    # 'a = b') and a rule example 'c = d' -> TRUE; each parameterized case is
    # classified relative to it.
    context = cl.FrozenExampleSet.from_examples([
        cl.Example(
            request='a',
            reply='b',
            metadata=cl.ExampleMetadata(
                as_rule='a = b',
                production=nltk_utils.production_from_production_string(
                    "A[sem=WALK] -> 'walk'"))),
        cl.Example(request='c = d', reply=cl.RuleReply.TRUE)
    ])
    triviality = example.get_triviality(context)
    # (expected, actual) order, consistent with the rest of this file.
    self.assertEqual(expected, triviality)

  @parameterized.named_parameters(
      ('Empty', cl.Example(), False),
      ('Request non-empty', cl.Example(request=' '), True),
      ('Reply non-empty', cl.Example(reply=' '), True))
  def test_implicit_bool_conversion(self, example, expected_value):
    self.assertEqual(expected_value, bool(example))

  @parameterized.named_parameters(
      ('Empty', cl.Example(), '<{}, , , M>'),
      ('Non-empty', cl.Example(
          request='a', reply='b', qualifier=cl.Qualifier.D), '<{}, a, b, D>'))
  def test_to_string_without_metadata(self, example, expected_string):
    with self.subTest('to_string'):
      self.assertEqual(expected_string, example.to_string())
    with self.subTest('str'):
      self.assertEqual(expected_string, str(example))
    with self.subTest('repr_is_same_as_str_when_metadata_is_empty'):
      self.assertEqual(expected_string, repr(example))

  def test_to_string_with_metadata(self):
    """str() hides metadata while repr() additionally shows it."""
    example = cl.Example(
        request='a',
        reply='b',
        qualifier=cl.Qualifier.D,
        metadata=cl.ExampleMetadata(
            rules={'A=B'},
            target_rule='A=B',
            derivation_level=1,
            num_variables=2))

    expected_string = '<{}, a, b, D>'

    actual_repr = repr(example)

    with self.subTest('to_string_is_same_as_str_by_default'):
      self.assertEqual(expected_string, example.to_string())

    with self.subTest('str_only_shows_public_content_not_metadata'):
      self.assertEqual(expected_string, str(example))

    with self.subTest('repr_subsumes_str'):
      self.assertContainsSubsequence(actual_repr, expected_string)

    with self.subTest('repr_includes_metadata'):
      self.assertContainsExactSubsequence(actual_repr, 'ExampleMetadata')

  @parameterized.named_parameters(
      ('without_train_similarity_metadata',
       _get_example_metadata_with_every_serializable_field(
           train_similarity=None)),
      ('with_every_serializable_field',
       _get_example_metadata_with_every_serializable_field()),
  )
  def test_roundtrip_serialize_with_metadata(self, original_metadata):
    """serialize() followed by deserialize() restores an equal Example."""
    original = cl.Example(request='a', reply='b', metadata=original_metadata)
    logging.info('original = %r', original)
    # Report (de)serialization crashes as test failures with full context,
    # rather than as test errors.
    try:
      unstructured = original.serialize()
    except Exception:
      self.fail(f'Exception raised when serializing Example to JSON:\n'
                f'  Example: {original}\n'
                f'  Error: {traceback.format_exc()}')
    logging.info('unstructured = %s', unstructured)
    try:
      restored = cl.Example.deserialize(unstructured)
    except Exception:
      self.fail(f'Exception raised when deserializing JSON Example:\n'
                f'  JSON: {unstructured}\n'
                f'  Error: {traceback.format_exc()}')
    logging.info('restored = %r', restored)
    with self.subTest('restores_example'):
      self.assertEqual(original, restored)
    with self.subTest('restores_example_metadata'):
      self.assertEqual(
          original.metadata,
          restored.metadata,
          msg=f'\n{original.metadata}\n\n{restored.metadata}')

  def test_to_simple_example(self):
    """to_simple_example() drops the context and keeps everything else."""
    example = cl.Example(
        request='a',
        reply='b',
        qualifier=cl.Qualifier.D,
        metadata=cl.ExampleMetadata(
            rules={'A=B'}, target_rule='A=B', derivation_level=1))
    simple_example = example.to_simple_example()

    with self.subTest('should_clear_context'):
      self.assertEqual(cl.FrozenExampleSet(), simple_example.context)
    with self.subTest('should_keep_request'):
      self.assertEqual(example.request, simple_example.request)
    with self.subTest('should_keep_reply'):
      self.assertEqual(example.reply, simple_example.reply)
    with self.subTest('should_keep_qualifier'):
      self.assertEqual(example.qualifier, simple_example.qualifier)
    with self.subTest('should_keep_example_metadata'):
      self.assertEqual(example.metadata, simple_example.metadata)


def _create_small_mutable_context_with_every_serializable_field():
  """Builds a minimal mutable ExampleSet for serialization tests.

  The set uses the INTERPRETATION_RULE rule format and holds one example
  ('c' -> 'd') whose metadata has every serializable field populated.
  """
  fully_annotated = cl.Example(
      request='c',
      reply='d',
      metadata=_get_example_metadata_with_every_serializable_field())
  small_context = cl.ExampleSet()
  small_context.metadata.rule_format = enums.RuleFormat.INTERPRETATION_RULE
  small_context.add_example(fully_annotated)
  return small_context


def _create_mutable_context_exercising_every_field_and_method(
    omitted_rule='E=F',
    explicit_rule='A=B',
    hidden_rule='C=D',
    unreliable_rule='K=L',
    distractor_rule='K=LA',
    grammar_string="""
    U[sem='F'] -> 'E'
    U[sem='B'] -> 'A'
    U[sem='D'] -> 'C'
    U[sem='K'] -> 'L'"""
):
  """Builds a mutable ExampleSet that touches every rule-adding method.

  Starting from the small serializable context, this adds one rule of each
  kind: omitted, explicit, hidden (with two illustrative examples) and
  unreliable (with one example that uses `distractor_rule`). It also marks the
  hidden rule's reply as TRUE and attaches a FeatureGrammar parsed from
  `grammar_string`.
  """
  result = _create_small_mutable_context_with_every_serializable_field()
  result.add_omitted_rule(omitted_rule)

  explicit_example = test_utils.create_example_from_explicit_rule(
      explicit_rule)
  result.add_explicit_rule(explicit_rule, explicit_example)

  hidden_examples = [
      cl.Example(
          request=f'c{index}',
          reply=f'd{index}',
          metadata=cl.ExampleMetadata(
              rules={hidden_rule}, target_rule=hidden_rule))
      for index in (1, 2)
  ]
  result.add_hidden_rule(hidden_rule, hidden_examples)

  result.mark_rule_as_unreliable(unreliable_rule)
  distractor_metadata = cl.ExampleMetadata(
      rules={distractor_rule},
      distractor_rules_by_unreliable_rule={unreliable_rule: [distractor_rule]})
  result.add_unreliable_rule(unreliable_rule, [
      cl.Example(request='k1', reply='l1', metadata=distractor_metadata),
  ])

  result.metadata.rule_reply_by_hidden_rule = {hidden_rule: cl.RuleReply.TRUE}
  result.metadata.grammar = nltk.grammar.FeatureGrammar.fromstring(
      grammar_string)
  return result


def _create_context():
  """Freezes the fully exercised mutable test context into a FrozenExampleSet."""
  mutable_context = _create_mutable_context_exercising_every_field_and_method()
  return cl.FrozenExampleSet.from_example_set(mutable_context)


def _create_example_set(train_similarity=None):
  """Builds a top-level ExampleSet with two contrasting examples.

  The first example carries the rich frozen test context. The second example
  has no context, is defeasible, and has fully populated metadata.

  Args:
    train_similarity: Stored as the set's `metadata.train_similarity`.
  """
  with_context = cl.Example(context=_create_context(), request='a', reply='b')
  with_metadata = cl.Example(
      request='c',
      reply='d',
      qualifier=cl.Qualifier.D,
      metadata=_get_example_metadata_with_every_serializable_field())
  top_level = cl.ExampleSet.from_examples([with_context, with_metadata])
  top_level.metadata.train_similarity = train_similarity
  return top_level


def _create_example_group(with_context=True):
  """Returns an ExampleGroup holding the examples 'a' -> 'b' and 'c' -> 'd'.

  Args:
    with_context: If True, the group uses the rich test context from
      `_create_context()`; otherwise it uses an empty FrozenExampleSet.
  """
  group_context = _create_context() if with_context else cl.FrozenExampleSet()
  group_examples = [
      cl.Example(request=request, reply=reply)
      for request, reply in (('a', 'b'), ('c', 'd'))
  ]
  return cl.ExampleGroup(context=group_context, examples=group_examples)


def _create_grouped_example_set():
  """Returns a GroupedExampleSet: one group with context, one without."""
  groups = [
      _create_example_group(with_context=use_context)
      for use_context in (True, False)
  ]
  return cl.GroupedExampleSet(example_groups=groups)


class ExampleSetMetadataTest(absltest.TestCase):
  """Tests ExampleSetMetadata accessors and (de)serialization."""

  def test_rule_reply_by_hidden_rules_empty(self):
    # Without rule_reply_by_hidden_rule, no hidden rule is reported as either
    # true or unknown.
    metadata = cl.ExampleSetMetadata(hidden_rules=['a', 'b'])
    self.assertEqual([], metadata.hidden_true_rules)
    self.assertEqual([], metadata.hidden_unknown_rules)

  def test_rule_reply_by_hidden_rules_non_empty(self):
    metadata = cl.ExampleSetMetadata(
        hidden_rules=['a', 'b'],
        rule_reply_by_hidden_rule={
            'a': cl.RuleReply.TRUE,
            'b': cl.RuleReply.UNKNOWN
        })
    self.assertEqual(['a'], metadata.hidden_true_rules)
    self.assertEqual(['b'], metadata.hidden_unknown_rules)

  def test_serialize_variable_substitutions(self):
    """Substitution sets are serialized as summary counts, not raw sets."""
    original = cl.ExampleSetMetadata(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        variable_substitutions_by_rule={
            'rule1': {
                'x1': {'a', 'b', 'c'},
                'x2': {'d', 'e'},
            },
            'rule2': {
                'x1': {'x1'}
            },
            'rule3': {},
        },
        outer_substitutions_by_rule={
            'rule1': {'__ f', 'g __'},
            'rule2': {'__ h', '__'},
            'rule3': {'__'},
        },
        reliable_variable_substitutions_by_rule={
            'rule1': {
                'x1': {'a'}
            },
        },
        reliable_outer_substitutions_by_rule={'rule1': {'__ f', 'g __'}})

    unstructured = original.serialize()

    with self.subTest('rule_format'):
      self.assertEqual(
          'INTERPRETATION_RULE', unstructured.get('rule_format', None))
    with self.subTest('variable_substitutions_by_rule'):
      self.assertEqual(
          {
              # Minimum of the substitution counts across variables, with the
              # special value 1000000 ("infinite") used in cases where there was
              # no variable, or when all variables were observed in
              # unsubstituted form.
              'rule1': 2,
              'rule2': 1000000,
              'rule3': 1000000,
          },
          unstructured.get('min_num_variable_substitutions_by_rule', None))
    with self.subTest('outer_substitutions_by_rule'):
      self.assertDictEqual(
          {
              # Count of the substitutions, with the special value 1000000
              # ("infinite") used in cases where the rule occurred at least once
              # as the top of the rule application tree (i.e., with outer
              # context of '__').
              'rule1': 2,
              'rule2': 1000000,
              'rule3': 1000000,
          },
          unstructured.get('num_outer_substitutions_by_rule', None))
    with self.subTest('reliable_variable_substitutions_by_rule'):
      self.assertEqual(
          {
              'rule1': 1,
          },
          unstructured.get('min_num_reliable_variable_substitutions_by_rule',
                           None))
    with self.subTest('reliable_outer_substitutions_by_rule'):
      self.assertDictEqual(
          {
              'rule1': 2,
          },
          unstructured.get('num_reliable_outer_substitutions_by_rule', None))

  def test_roundtrip_serialize_grammar(self):
    """The grammar's productions survive serialize() + deserialize()."""
    grammar_string = """
    C[sem=(?x1+?x2)] -> S[sem=?x1] and S[sem=?x2]
    S[sem=(?x1+?x1)] -> U[sem=?x1] 'twice'
    U[sem='JUMP'] -> 'jump'
    U[sem='WALK'] -> 'walk'"""
    examples_with_metadata = (
        _create_mutable_context_exercising_every_field_and_method(
            grammar_string=grammar_string))
    unstructured_metadata = examples_with_metadata.metadata.serialize()
    restored_metadata = cl.ExampleSetMetadata()
    restored_metadata.deserialize(unstructured_metadata)
    expected_productions = [
        nltk_utils.production_from_production_string(
            'C[sem=(?x1+?x2)] -> S[sem=?x1] and S[sem=?x2]'),
        nltk_utils.production_from_production_string(
            "S[sem=(?x1+?x1)] -> U[sem=?x1] 'twice'"),
        nltk_utils.production_from_production_string("U[sem='JUMP'] -> 'jump'"),
        nltk_utils.production_from_production_string("U[sem='WALK'] -> 'walk'"),
    ]
    self.assertCountEqual(expected_productions,
                          restored_metadata.grammar.productions())


class ExampleSetTest(parameterized.TestCase):

  def test_add_example(self):
    """add_example stores the given object itself and returns it."""
    dataset = cl.ExampleSet()
    original_example = cl.Example(request='a', reply='b')
    # Capture the return value of the first (and only) insertion. Adding the
    # same example a second time would exercise the duplicate-example path
    # instead, which test_add_example_duplicate_example covers.
    returned_example = dataset.add_example(original_example)
    with self.subTest('contains_exactly_the_inserted_example'):
      self.assertLen(dataset, 1)
    with self.subTest('does_not_unnecessarily_copy_the_example'):
      self.assertIs(original_example, dataset[0])
    with self.subTest('returns_the_newly_inserted_example'):
      self.assertIs(original_example, returned_example)

  @parameterized.named_parameters(
      ('reliable_example', True),
      ('unreliable_example', False),
  )
  def test_add_example_should_populate_metadata(self, reliable):
    """add_example updates every per-rule and per-type index in the metadata."""
    parse = nltk_utils.production_from_production_string
    rule = 'some rule'

    dataset = cl.ExampleSet()
    dataset.metadata.rule_format = enums.RuleFormat.INTERPRETATION_RULE
    example = cl.Example(
        request='a',
        reply='b',
        metadata=_get_example_metadata_with_every_serializable_field(
            rule=rule,
            production=parse("B[sem=b] -> 'a'"),
            source_production=parse('B[sem=?x1] -> A[sem=?x1]'),
            composed_production=parse("A[sem=b] -> 'a'"),
            reliable=reliable))
    dataset.add_example(example)
    metadata = dataset.metadata

    # Expected indices, keyed by the interpretation rule '[a] = b'.
    variable_substitutions = {'[a] = b': {}}
    outer_substitutions = {'[a] = b': {'__'}}
    # Only reliable examples feed the reliable_* indices.
    reliable_variable_substitutions = (
        variable_substitutions if reliable else {})
    reliable_outer_substitutions = outer_substitutions if reliable else {}

    with self.subTest('metadata_should_track_examples_by_rule'):
      self.assertIn(rule, metadata.examples_by_rule)
      self.assertEqual(metadata.examples_by_rule.get(rule, []), [example])

    with self.subTest('metadata_should_track_examples_by_example_type'):
      example_type = example.get_example_type()
      self.assertEqual(metadata.examples_by_example_type[example_type],
                       [example])

    with self.subTest('metadata_should_track_variable_substitutions_by_rule'):
      self.assertDictEqual(variable_substitutions,
                           metadata.variable_substitutions_by_rule)

    with self.subTest('metadata_should_track_outer_substitutions_by_rule'):
      self.assertDictEqual(outer_substitutions,
                           metadata.outer_substitutions_by_rule)

    with self.subTest(
        'metadata_should_track_reliable_variable_substitutions_by_rule'):
      self.assertDictEqual(reliable_variable_substitutions,
                           metadata.reliable_variable_substitutions_by_rule)

    with self.subTest(
        'metadata_should_track_reliable_outer_substitutions_by_rule'):
      self.assertDictEqual(reliable_outer_substitutions,
                           metadata.reliable_outer_substitutions_by_rule)

  def test_add_example_duplicate_example(self):
    """Adding an equal example is a no-op that returns the stored original."""
    first = cl.Example(request='a', reply='b')
    second = cl.Example(request='c', reply='d')
    equal_to_first = cl.Example(request='a', reply='b')
    dataset = cl.ExampleSet.from_examples([first, second])

    result = dataset.add_example(equal_to_first)

    with self.subTest('does_not_add_the_duplicate_example'):
      self.assertLen(dataset, 2)
    with self.subTest('returns_the_existing_equivalent_example'):
      self.assertIs(first, result)

  def test_add_omitted_rule(self):
    """Omitted rules are recorded in the metadata without adding examples."""
    dataset = cl.ExampleSet()
    dataset.add_omitted_rule('A=B')
    dataset.add_omitted_rule('C=D')

    with self.subTest(
        name='rules_and_corresponding_examples_should_be_added_to_the_correct_'
        'data_structures_in_ExampleSetMetadata'):
      # assertEmpty's second positional argument is the failure message, not
      # an expected length, so nothing else is passed here.
      self.assertEmpty(dataset)
      self.assertLen(dataset.metadata.rules, 2)
      self.assertLen(dataset.metadata.omitted_rules, 2)
      self.assertEmpty(dataset.metadata.explicit_rules)
      self.assertEmpty(dataset.metadata.hidden_rules)
      self.assertEmpty(dataset.metadata.example_indices)

    with self.subTest('should_now_contain_the_rule'):
      self.assertTrue(dataset.metadata.contains_rule('A=B'))
      self.assertTrue(dataset.metadata.contains_rule('C=D'))

    with self.subTest('omitted_rules_should_also_contain_the_rule'):
      self.assertIn('A=B', dataset.metadata.omitted_rules)
      self.assertIn('C=D', dataset.metadata.omitted_rules)

    with self.subTest(
        name='should_be_able_to_check_for_membership_in_the_rule_and_example_'
        'sets_using_a_newly_created_but_equivalent_string_or_Example_object'):
      self.assertTrue(dataset.metadata.contains_rule('A=B'))

    with self.subTest(
        name='should_preserve_the_order_in_which_the_rules_were_added'):
      self.assertEqual('A=B', dataset.metadata.rules[0])

  def test_add_omitted_rule_raises_error_if_rule_already_present(self):
    """A rule cannot be re-added as omitted, even when it is already hidden."""
    duplicate_rule = 'A=B'
    dataset = cl.ExampleSet()
    dataset.add_hidden_rule(duplicate_rule, [])
    with self.assertRaisesRegex(ValueError, 'Rule already present'):
      dataset.add_omitted_rule(duplicate_rule)

  def test_add_explicit_rule(self):
    """Explicit rules are recorded in metadata together with their examples."""
    first_rule, second_rule = 'A=B', 'C=D'
    dataset = cl.ExampleSet()
    for new_rule in (first_rule, second_rule):
      dataset.add_explicit_rule(
          new_rule, test_utils.create_example_from_explicit_rule(new_rule))
    metadata = dataset.metadata

    with self.subTest(
        name='rules_and_corresponding_examples_should_be_added_to_the_correct_'
        'data_structures_in_ExampleSetMetadata'):
      self.assertLen(dataset, 2)
      self.assertLen(metadata.rules, 2)
      self.assertLen(metadata.explicit_rules, 2)
      self.assertEmpty(metadata.hidden_rules)
      self.assertLen(metadata.example_indices, 2)

    with self.subTest('should_now_contain_the_rule'):
      self.assertTrue(metadata.contains_rule(first_rule))

    with self.subTest('explicit_rules_should_also_contain_the_rule'):
      self.assertIn(first_rule, metadata.explicit_rules)

    with self.subTest(
        name='should_be_able_to_check_for_membership_in_the_rule_and_example_'
        'sets_using_a_newly_created_but_equivalent_string_or_Example_object'):
      self.assertTrue(metadata.contains_rule(first_rule))
      equivalent_example = cl.Example(request=first_rule, reply='1')
      self.assertIn(equivalent_example, metadata.example_indices)

    with self.subTest(
        name='should_preserve_the_order_in_which_the_rules_were_added'):
      self.assertEqual(first_rule, metadata.rules[0])

  def test_add_explicit_rule_raises_error_if_rule_already_present(self):
    """A rule cannot be re-added as explicit, even when it is already hidden."""
    duplicate_rule = 'A=B'
    dataset = cl.ExampleSet()
    dataset.add_hidden_rule(duplicate_rule, [])
    with self.assertRaisesRegex(ValueError, 'Rule already present'):
      explicit_example = test_utils.create_example_from_explicit_rule(
          duplicate_rule)
      dataset.add_explicit_rule(duplicate_rule, explicit_example)

  def test_add_hidden_rule(self):
    """Hidden rules and their illustrative examples are recorded in metadata."""
    dataset = cl.ExampleSet()
    dataset.add_hidden_rule('A=B', [])
    illustrative_examples = [
        cl.Example(request=f'c{i}', reply=f'd{i}') for i in (1, 2)
    ]
    dataset.add_hidden_rule('C=D', illustrative_examples)
    metadata = dataset.metadata

    with self.subTest(
        name='rules_and_corresponding_examples_should_be_added_to_the_correct_'
        'data_structures_in_ExampleSetMetadata'):
      self.assertLen(dataset, 2)
      self.assertLen(metadata.rules, 2)
      self.assertEmpty(metadata.explicit_rules)
      self.assertLen(metadata.hidden_rules, 2)
      self.assertLen(metadata.example_indices, 2)

    with self.subTest('should_now_contain_the_rule'):
      for added_rule in ('A=B', 'C=D'):
        self.assertTrue(metadata.contains_rule(added_rule))

    with self.subTest('hidden_rules_should_also_contain_the_rule'):
      for added_rule in ('A=B', 'C=D'):
        self.assertIn(added_rule, metadata.hidden_rules)

  def test_add_hidden_rule_raises_error_if_rule_already_present(self):
    """A rule cannot be re-added as hidden, even when it is already explicit."""
    duplicate_rule = 'A=B'
    dataset = cl.ExampleSet()
    explicit_example = test_utils.create_example_from_explicit_rule(
        duplicate_rule)
    dataset.add_explicit_rule(duplicate_rule, explicit_example)
    with self.assertRaisesRegex(ValueError, 'Rule already present'):
      dataset.add_hidden_rule(duplicate_rule, [])

  def test_add_hidden_rule_should_not_unnecessarily_copy_examples(self):
    """The stored example is the very object passed to add_hidden_rule."""
    illustrative = cl.Example(request='c1', reply='d1')
    dataset = cl.ExampleSet()
    dataset.add_hidden_rule('C=D', [illustrative])
    stored = dataset[0]
    self.assertIs(illustrative, stored)

  def test_add_hidden_rule_with_existing_illustrative_example(self):
    """An already-present illustrative example is not duplicated."""
    hidden = 'C=D'
    existing = cl.Example(request='c1', reply='d1')
    equivalent = cl.Example(request='c1', reply='d1')
    dataset = cl.ExampleSet.from_examples([existing])
    dataset.add_hidden_rule(hidden, [equivalent])
    metadata = dataset.metadata
    with self.subTest('does_not_add_the_duplicate_example'):
      self.assertLen(dataset, 1)
    with self.subTest('does_add_the_rule'):
      self.assertEqual(hidden, metadata.rules[0])
      self.assertTrue(metadata.contains_rule(hidden))

  def test_contains(self):
    """Membership (`in`) reflects exactly the examples that were added."""
    dataset = cl.ExampleSet()
    added = cl.Example(request='a', reply='b')
    never_added = cl.Example(request='c', reply='d')
    dataset.add_example(added)

    with self.subTest('contains_example_that_was_added'):
      self.assertIn(added, dataset)

    with self.subTest('does_not_contain_example_that_was_not_added'):
      self.assertNotIn(never_added, dataset)

  @parameterized.named_parameters(
      ('With examples', 'A=B', [cl.Example(request='a', reply='b')]),
      ('Without examples', 'A=B', []),
  )
  def test_contains_rule_after_adding_hidden_rule(self, rule, examples):
    """contains_rule reports hidden rules, with or without examples."""
    dataset = cl.ExampleSet()
    dataset.add_hidden_rule(rule, examples)
    metadata = dataset.metadata
    with self.subTest('contains_rule_that_was_added'):
      self.assertTrue(metadata.contains_rule(rule))
    with self.subTest('does_not_contain_rule_that_was_not_added'):
      self.assertFalse(metadata.contains_rule('Some other rule'))

  @parameterized.named_parameters(
      # Note: It's important to test deepcopy on a mutable context directly, as
      # the creation of a FrozenExampleSet already involves deepcopy internally.
      ('Mutable context',
       _create_mutable_context_exercising_every_field_and_method()),
      ('Top-level example set', _create_example_set()))
  def test_deepcopy_should_return_an_equivalent_example_set(self, original):
    """copy.deepcopy yields an ExampleSet with equal examples and metadata.

    Each aspect is checked in its own subTest. There is deliberately no bare
    assertion before the subTests: a failure there would abort the test and
    hide whether the metadata had also diverged.

    Args:
      original: The ExampleSet (mutable context or top-level set) to copy.
    """
    copied = copy.deepcopy(original)
    with self.subTest('equivalent_examples'):
      self.assertEqual(original, copied)
    with self.subTest('equivalent_metadata'):
      # Metadata is compared separately via repr, presumably because
      # ExampleSet equality only covers the examples -- confirm in cl.
      self.assertEqual(repr(original.metadata), repr(copied.metadata))

  @parameterized.named_parameters(
      ('Empty', cl.ExampleSet(), False),
      ('Non-empty',
       cl.ExampleSet.from_examples([cl.Example(request='a', reply='b')]),
       True),
  )
  def test_implicit_bool_conversion(self, dataset, expected_value):
    """An empty ExampleSet is falsy; a non-empty one is truthy."""
    is_truthy = bool(dataset)
    self.assertEqual(expected_value, is_truthy)

  @parameterized.named_parameters(
      ('Empty', cl.ExampleSet(), 0),
      ('Non-empty',
       cl.ExampleSet.from_examples([cl.Example(request='a', reply='b')]),
       1),
  )
  def test_len(self, dataset, expected_value):
    """len() of an ExampleSet is the number of examples it holds."""
    self.assertLen(dataset, expected_value)

  @parameterized.named_parameters(
      ('Empty', cl.ExampleSet(), '{}', False),
      ('Non-empty',
       cl.ExampleSet.from_examples([
           cl.Example(request='a', reply='b'),
           cl.Example(request='c'),
           cl.Example(reply='d'),
           cl.Example(request='e', reply='f', qualifier=cl.Qualifier.D)
       ]),
       textwrap.dedent("""\
          {<{}, a, b, M>
           <{}, c, , M>
           <{}, , d, M>
           <{}, e, f, D>}"""), False),
      ('Containing example with metadata',
       cl.ExampleSet.from_examples([
           cl.Example(
               request='a',
               reply='b',
               qualifier=cl.Qualifier.D,
               metadata=cl.ExampleMetadata(rules={'A=B'}, target_rule='A=B'))
       ]), textwrap.dedent("""\
          {<{}, a, b, D>}"""), True))
  def test_to_string(self, dataset, expected_string, has_example_metadata):
    """to_string/str/repr agree across ExampleSet and FrozenExampleSet.

    Args:
      dataset: The ExampleSet to render.
      expected_string: Expected output of to_string() and str().
      has_example_metadata: Whether any example in `dataset` carries metadata,
        in which case repr() must mention the metadata types.
    """
    frozen_dataset = cl.FrozenExampleSet.from_example_set(dataset)
    dataset_repr = repr(dataset)
    frozen_repr = repr(frozen_dataset)

    with self.subTest('ExampleSet_to_string'):
      self.assertEqual(expected_string, dataset.to_string())
    with self.subTest('ExampleSet_str'):
      self.assertEqual(expected_string, str(dataset))
    with self.subTest('ExampleSet_repr_subsumes_str'):
      self.assertContainsSubsequence(dataset_repr, expected_string)
    with self.subTest('ExampleSet_repr_includes_metadata_if_nonempty'):
      if has_example_metadata:
        for metadata_type in ('ExampleMetadata', 'ExampleSetMetadata'):
          self.assertContainsExactSubsequence(dataset_repr, metadata_type)

    with self.subTest('FrozenExampleSet_to_string'):
      self.assertEqual(expected_string, frozen_dataset.to_string())
    with self.subTest('FrozenExampleSet_str'):
      self.assertEqual(expected_string, str(frozen_dataset))
    with self.subTest('FrozenExampleSet_repr_same_as_non_frozen'):
      self.assertEqual(dataset_repr, frozen_repr)

  def test_to_string_nested(self):
    """A context nested inside an example is rendered inline and indented."""
    expected_string = textwrap.dedent("""\
        {<{}, a, b, M>
         <{<{}, c, d, M>
           <{}, e, f, D>}, g, h, M>
         <{}, i, j, D>}""")

    inner_context = cl.FrozenExampleSet.from_examples([
        cl.Example(request='c', reply='d'),
        cl.Example(request='e', reply='f', qualifier=cl.Qualifier.D),
    ])
    outer = cl.ExampleSet.from_examples([
        cl.Example(request='a', reply='b'),
        cl.Example(inner_context, 'g', 'h'),
        cl.Example(request='i', reply='j', qualifier=cl.Qualifier.D),
    ])

    self.assertEqual(expected_string, outer.to_string())

  def test_unstructured_example_set_is_json_serializable(self):
    """The output of ExampleSet.serialize() can be encoded by json.dumps."""
    dataset = _create_example_set()
    logging.info('Original ExampleSet:\n%s', dataset)
    unstructured = dataset.serialize()
    logging.info('Unstructured ExampleSet:\n%s', unstructured)
    # json.dumps raises TypeError for values of unsupported types and
    # ValueError for circular references. Either means the serialized form is
    # not valid JSON, so report it as a test failure rather than an error.
    try:
      dataset_as_json = json.dumps(unstructured)
    except (TypeError, ValueError):
      self.fail(f'Exception raised when converting ExampleSet to JSON: '
                f'{traceback.format_exc()}')
    else:
      logging.info('ExampleSet as JSON:\n%s', dataset_as_json)

  def test_json_representation_of_example_set_is_dict_of_examples_and_metadata(
      self):
    """serialize() yields just '_examples' and 'metadata'; empty contexts→None."""
    unstructured_dataset = cl.ExampleSet.from_examples(
        [cl.Example()]).serialize()
    with self.subTest('contains_only_examples_and_metadata'):
      self.assertCountEqual({'_examples', 'metadata'},
                            unstructured_dataset.keys())
    with self.subTest('empty_contexts_are_abbreviated_as_null'):
      # For readability, we prefer null over {"_examples": [], ...}.
      first_example = unstructured_dataset['_examples'][0]
      self.assertIsNone(first_example['context'])

  @parameterized.named_parameters(
      ('simple_context',
       _create_small_mutable_context_with_every_serializable_field()),
      ('more_exhaustive_context',
       _create_mutable_context_exercising_every_field_and_method()),
      ('top_level_example_set_with_train_similarity_metadata',
       _create_example_set(
           train_similarity=cl.ExampleSetTrainSimilarityMetadata(
               nearest_similarity_by_rule_overlap=0.1,
               nearest_similarity_by_example_overlap=0.2,
           ))))
  def test_roundtrip_serialize(self, original):
    """deserialize(serialize(x)) restores both the examples and the metadata.

    Args:
      original: The ExampleSet to round-trip.
    """
    logging.info('Original ExampleSet:\n%s', original)
    serialized = original.serialize()
    logging.info('Unstructured ExampleSet:\n%s', serialized)
    roundtripped = cl.ExampleSet.deserialize(serialized)
    logging.info('Restored ExampleSet:\n%s', roundtripped)

    with self.subTest('restores_examples'):
      self.assertEqual(original, roundtripped)
    with self.subTest('restores_metadata'):
      # For long single-line strings, assertSequenceEqual pinpoints the first
      # mismatching position more readably than assertEqual does.
      self.maxDiff = None
      self.assertSequenceEqual(
          repr(original.metadata), repr(roundtripped.metadata))

  def test_roundtrip_serialize_with_rules(self):
    """Round-trips an ExampleSet that uses every category of rule.

    Checks that deserialize(serialize(x)) restores the examples and metadata.
    Also checks the per-rule example counts that serialize() writes into the
    unstructured metadata.
    """
    # Add some explicit rules.
    original = cl.ExampleSet()
    original.add_explicit_rule(
        'A=B', test_utils.create_example_from_explicit_rule('A=B'))
    original.add_explicit_rule(
        'C=D', test_utils.create_example_from_explicit_rule('C=D'))

    # Add some hidden rules.
    original.add_hidden_rule('E=F', [])
    original.add_hidden_rule('G=H', [
        cl.Example(
            request='g1',
            reply='h1',
            metadata=cl.ExampleMetadata(rules={'G=H'})),
        cl.Example(
            request='g2',
            reply='h2',
            metadata=cl.ExampleMetadata(rules={'G=H'})),
    ])

    # Add some unreliable rules. Each example illustrating K=L actually uses one
    # of its distractor variants (K=LA or K=LB), together with A=B and G=H.
    original.mark_rule_as_unreliable('I=J')
    original.mark_rule_as_unreliable('K=L')
    original.add_unreliable_rule('I=J', [])
    original.add_unreliable_rule('K=L', [
        cl.Example(
            request='k1',
            reply='l1',
            metadata=cl.ExampleMetadata(
                rules={'A=B', 'G=H', 'K=LA'},
                distractor_rules_by_unreliable_rule={'K=L': ['K=LA']})),
        cl.Example(
            request='k2',
            reply='l2',
            metadata=cl.ExampleMetadata(
                rules={'A=B', 'G=H', 'K=LB'},
                distractor_rules_by_unreliable_rule={'K=L': ['K=LB']})),
    ])

    # Add some omitted rules.
    original.add_omitted_rule('M=N')
    original.add_omitted_rule('O=P')

    # Populate rule_reply_by_hidden_rule.
    original.metadata.rule_reply_by_hidden_rule = {
        'E=F': cl.RuleReply.TRUE,
        'G=H': cl.RuleReply.UNKNOWN
    }

    unstructured = original.serialize()
    unstructured_metadata = unstructured['metadata']
    restored = cl.ExampleSet.deserialize(unstructured)

    with self.subTest('restores_examples'):
      self.assertEqual(original, restored)
    with self.subTest('restores_metadata_including_rules'):
      self.assertEqual(
          original.metadata,
          restored.metadata,
          msg=f'\n{original.metadata}\n\n{restored.metadata}')
    with self.subTest('contains_correct_example_count_for_explicit_rules'):
      # A=B: its explicit-rule example plus k1 and k2, whose rules include A=B.
      self.assertEqual(unstructured_metadata['explicit_rules'], {
          'A=B': 3,
          'C=D': 1
      })
    with self.subTest('contains_correct_example_count_for_hidden_rules'):
      # G=H: g1 and g2 plus k1 and k2, whose rules also include G=H.
      self.assertEqual(unstructured_metadata['hidden_rules'], {
          'E=F': 0,
          'G=H': 4
      })
    with self.subTest('example_count_for_unreliable_rules'):
      # Includes examples in which a distractor variant of the rule was used.
      self.assertEqual(unstructured_metadata['unreliable_rules'], {
          'I=J': 0,
          'K=L': 2
      })
    with self.subTest('example_count_for_omitted_rules'):
      self.assertEqual(unstructured_metadata['omitted_rules'], {
          'M=N': 0,
          'O=P': 0
      })
    with self.subTest('example_count_for_distractor_rules'):
      self.assertEqual(unstructured_metadata['distractor_rules'], {
          'K=LA': 1,
          'K=LB': 1
      })


class ExampleSetAddExampleTest(absltest.TestCase):
  """Tests how ExampleSet.add_example treats duplicate and distinct examples."""

  def setUp(self):
    super().setUp()
    initial_examples = [
        cl.Example(request='a', reply='b'),
        cl.Example(request='c', reply='d'),
    ]
    self.dataset = cl.ExampleSet.from_examples(initial_examples)

  def test_different_examples_can_be_added_successfully(self):
    self.assertLen(self.dataset, 2)

  def test_duplicate_examples_are_ignored(self):
    # Both duplicate the ('a', 'b') example from setUp; the second just spells
    # out the qualifier M that an example without one is rendered with.
    duplicates = (
        cl.Example(request='a', reply='b'),
        cl.Example(request='a', reply='b', qualifier=cl.Qualifier.M),
    )
    for duplicate in duplicates:
      self.dataset.add_example(duplicate)

    self.assertLen(self.dataset, 2)

  def test_example_only_considered_duplicate_if_all_fields_match(self):
    context = cl.FrozenExampleSet.from_examples(
        [cl.Example(request='e', reply='f')])
    # Each of these differs from the ('a', 'b') example in exactly one field.
    near_duplicates = [
        cl.Example(request='a', reply='b', context=context),
        cl.Example(request='a', reply='b', qualifier=cl.Qualifier.D),
        cl.Example(request='a'),
        cl.Example(reply='b'),
    ]
    for example in near_duplicates:
      self.dataset.add_example(example)

    self.assertLen(self.dataset, 6)


class FrozenExampleSetTest(absltest.TestCase):
  """Tests for FrozenExampleSet request lookup and serialization."""

  def test_request_already_in_example_set(self):
    seen_request = 'request'
    frozen = cl.FrozenExampleSet.from_example_set(
        cl.ExampleSet.from_examples(
            [cl.Example(request=seen_request, reply='reply')]))
    with self.subTest('should_return_False_on_unseen_request'):
      self.assertFalse(
          frozen.request_already_in_example_set('something else'))
    with self.subTest('should_return_True_on_seen_request'):
      self.assertTrue(frozen.request_already_in_example_set(seen_request))

  def test_serialize_should_have_correct_keys(self):
    frozen = cl.FrozenExampleSet.from_example_set(
        cl.ExampleSet.from_examples(
            [cl.Example(request='request', reply='reply')]))
    serialized_keys = frozen.serialize().keys()
    self.assertCountEqual(serialized_keys, ['_examples', 'metadata'])


class FrozenExampleSetFromExampleSetTest(absltest.TestCase):
  """Tests for FrozenExampleSet.from_example_set."""

  def test_frozen_copy_should_contain_equivalent_contents_to_original(self):
    original = _create_example_set()
    snapshot = copy.deepcopy(original)
    frozen = cl.FrozenExampleSet.from_example_set(original)
    with self.subTest('same_examples'):
      self.assertEqual(list(snapshot), list(frozen))
    with self.subTest('same_metadata'):
      self.assertEqual(snapshot.metadata, frozen.metadata)
    with self.subTest('same_length'):
      self.assertLen(frozen, len(snapshot))
    with self.subTest('same_implicit_boolean_value'):
      self.assertEqual(bool(snapshot), bool(frozen))

  def test_freezing_does_not_unnecessarily_copy_examples_basic(self):
    # Examples are already immutable, so freezing has no need to clone them.
    # Avoiding clones also matters for size. Each Example's context can nest
    # further ExampleSets arbitrarily deeply. The ExampleSetMetadata also holds
    # several references to each Example. Together, cloning could give the
    # FrozenExampleSet an order of magnitude or more Example instances than
    # the original.
    original = cl.ExampleSet.from_examples([cl.Example(request='a', reply='b')])
    frozen = cl.FrozenExampleSet.from_example_set(original)
    self.assertIs(original[0], frozen[0])

  def test_adding_content_to_original_should_leave_frozen_copy_unchanged(self):
    # Build an initial ExampleSet that touches every rule category.
    omitted_rule1 = 'E=F'
    explicit_rule1 = 'A=B'
    hidden_rule1 = 'C=D'
    unreliable_rule1 = 'K=L'
    distractor_rule1 = 'K=LA'
    example_set = _create_mutable_context_exercising_every_field_and_method(
        omitted_rule=omitted_rule1,
        explicit_rule=explicit_rule1,
        hidden_rule=hidden_rule1,
        unreliable_rule=unreliable_rule1,
        distractor_rule=distractor_rule1)
    snapshot = copy.deepcopy(example_set)
    frozen = cl.FrozenExampleSet.from_example_set(example_set)
    snapshot_metadata_repr = repr(snapshot.metadata)

    # Mutate the original through each kind of "add" method. The goal is to
    # touch every rule type, and with it every metadata field and
    # substructure.

    # A new explicit rule.
    explicit_rule2 = 'G=H'
    example_set.add_explicit_rule(
        explicit_rule2,
        test_utils.create_example_from_explicit_rule(explicit_rule2))

    # A new hidden rule.
    hidden_examples2 = [
        cl.Example(request='g1', reply='h1'),
        cl.Example(request='g2', reply='h2'),
    ]
    example_set.add_hidden_rule('I=J', hidden_examples2)

    # A new unreliable rule, illustrated through one of its distractors.
    unreliable_rule2 = 'M=N'
    distractor_rule2 = 'M=NA'
    example_set.mark_rule_as_unreliable(unreliable_rule2)
    distractor_example2 = cl.Example(
        request='m1',
        reply='n1',
        metadata=cl.ExampleMetadata(
            rules={distractor_rule2},
            distractor_rules_by_unreliable_rule={
                unreliable_rule2: [distractor_rule2]
            }))
    example_set.add_unreliable_rule(unreliable_rule2, [distractor_example2])

    # An example that reuses the pre-existing rules.
    existing_rules_metadata = cl.ExampleMetadata(
        rules={omitted_rule1, explicit_rule1, hidden_rule1, distractor_rule1},
        distractor_rules_by_unreliable_rule={
            unreliable_rule1: [distractor_rule1, distractor_rule2]
        })
    example_set.add_example(
        cl.Example(request='e2', reply='f2', metadata=existing_rules_metadata))

    with self.subTest(
        name='frozen_copy_equivalent_to_the_original_at_time_of_freezing'):
      self.assertEqual(list(snapshot), list(frozen))
      self.assertEqual(snapshot.metadata, frozen.metadata)
      # As an extra precaution, also compare against the repr taken before any
      # mutation. This guards against possible bugs in
      # ExampleSet.__deepcopy__ that would make the deep-copied snapshot itself
      # untrustworthy.
      self.assertEqual(snapshot_metadata_repr, repr(frozen.metadata))
    with self.subTest(
        name='frozen_copy_not_equivalent_to_the_modified_version'):
      self.assertNotEqual(list(example_set), list(frozen))
      self.assertNotEqual(example_set.metadata, frozen.metadata)


class ExampleGroupTest(absltest.TestCase):
  """Tests for ExampleGroup insertion, serialization and rendering."""

  def test_add_example(self):
    group = cl.ExampleGroup()
    new_example = cl.Example(request='a', reply='b')
    result = group.add_example(new_example)
    with self.subTest('Does not unnecessarily copy the example.'):
      self.assertIs(new_example, group[0])
    with self.subTest('Returns the newly inserted example.'):
      self.assertIs(new_example, result)

  def test_add_example_duplicate_example(self):
    group = cl.ExampleGroup(examples=[
        cl.Example(request='a', reply='b'),
        cl.Example(request='c', reply='d'),
    ])
    result = group.add_example(cl.Example(request='a', reply='b'))
    with self.subTest('Does not add the duplicate example.'):
      self.assertLen(group, 2)
    with self.subTest('Returns None'):
      self.assertIsNone(result)

  def test_serialize_should_have_correct_keys(self):
    unstructured = _create_example_group(with_context=True).serialize()
    self.assertCountEqual(unstructured.keys(), ['context', '_examples'])
    self.assertCountEqual(unstructured['context'].keys(),
                          ['_examples', 'metadata'])

  def test_roundtrip_deserialize(self):
    original = _create_example_group(with_context=True)
    restored = cl.ExampleGroup.deserialize(original.serialize())

    self.assertEqual(original, restored)

  def test_to_string(self):
    expected_string = textwrap.dedent("""\
      ({<{}, c, d, M>}
       {<{}, a, b, M>})""")

    context = cl.FrozenExampleSet.from_examples([
        cl.Example(
            request='c',
            reply='d',
            metadata=cl.ExampleMetadata(rules={'A=B'}, target_rule='A=B')),
    ])
    group = cl.ExampleGroup(
        context=context,
        examples=[
            cl.Example(
                request='a',
                reply='b',
                metadata=cl.ExampleMetadata(rules={'A=B'})),
        ])
    group_repr = repr(group)

    with self.subTest('ExampleGroup_to_string'):
      self.assertEqual(expected_string, group.to_string())
    with self.subTest('ExampleGroup_str'):
      self.assertEqual(expected_string, str(group))
    with self.subTest('ExampleGroup_repr_subsumes_str'):
      self.assertContainsSubsequence(group_repr, expected_string)
    with self.subTest('ExampleGroup_repr_includes_metadata'):
      for metadata_type in ('ExampleMetadata', 'ExampleSetMetadata'):
        self.assertContainsExactSubsequence(group_repr, metadata_type)


class GroupedExampleSetTest(parameterized.TestCase):
  """Tests for GroupedExampleSet and its conversions to/from ExampleSet."""

  def test_from_example_set_should_raise_if_nested_context(self):
    innermost = cl.FrozenExampleSet.from_examples(
        [cl.Example(request='a', reply='b')])
    middle = cl.FrozenExampleSet.from_examples(
        [cl.Example(context=innermost, request='c', reply='d')])
    example_set = cl.ExampleSet.from_examples(
        [cl.Example(context=middle, request='e', reply='f')])

    with self.assertRaisesRegex(
        ValueError, 'ExampleSet with nested context cannot be converted to the '
        'GroupedExampleSet format.'):
      cl.GroupedExampleSet.from_example_set(example_set)

  def test_from_example_set(self):
    example_set = _create_example_set()
    original_contexts = {example.context for example in example_set}

    grouped = cl.GroupedExampleSet.from_example_set(example_set)
    grouped_contexts = {group.context for group in grouped.example_groups}
    restored = grouped.to_example_set()

    with self.subTest('should_keep_all_the_contexts'):
      self.assertEqual(original_contexts, grouped_contexts)
    with self.subTest(
        name='roundtrip_should_restore_original_example_set_up_to_'
        'reordering_of_the_examples'):
      self.assertCountEqual(restored, example_set)

  def test_to_example_set(self):
    grouped = _create_grouped_example_set()
    original_contexts = {group.context for group in grouped.example_groups}

    example_set = grouped.to_example_set()
    flattened_contexts = {example.context for example in example_set}
    restored = cl.GroupedExampleSet.from_example_set(example_set)

    with self.subTest('should_keep_all_the_contexts'):
      self.assertEqual(original_contexts, flattened_contexts)
    with self.subTest(
        name='roundtrip_should_restore_original_example_set_exactly'):
      self.assertEqual(restored, grouped)

  def test_serialize_should_have_correct_keys(self):
    serialized = _create_grouped_example_set().serialize()
    self.assertCountEqual(serialized.keys(), ['example_groups'])

  def test_roundtrip_deserialize(self):
    grouped = _create_grouped_example_set()
    restored = cl.GroupedExampleSet.deserialize(grouped.serialize())

    self.assertEqual(grouped, restored)

  @parameterized.named_parameters(
      ('Empty', cl.ExampleSet(), '{}', False),
      ('Non-empty',
       cl.ExampleSet.from_examples([
           cl.Example(request='a', reply='b'),
           cl.Example(request='c'),
           cl.Example(reply='d'),
           cl.Example(request='e', reply='f', qualifier=cl.Qualifier.D)
       ]),
       textwrap.dedent("""\
          {({}
            {<{}, a, b, M>
             <{}, c, , M>
             <{}, , d, M>
             <{}, e, f, D>})}"""), False),
      ('Containing example with metadata',
       cl.ExampleSet.from_examples([
           cl.Example(
               request='a',
               reply='b',
               qualifier=cl.Qualifier.D,
               metadata=cl.ExampleMetadata(rules={'A=B'}, target_rule='A=B'))
       ]), textwrap.dedent("""\
          {({}
            {<{}, a, b, D>})}"""), True))
  def test_to_string(self, dataset, expected_string, has_example_metadata):
    """to_string/str/repr of the GroupedExampleSet built from `dataset`.

    Args:
      dataset: ExampleSet to convert into a GroupedExampleSet.
      expected_string: Expected output of to_string() and str().
      has_example_metadata: Whether repr() must mention ExampleMetadata.
    """
    grouped = cl.GroupedExampleSet.from_example_set(dataset)
    grouped_repr = repr(grouped)

    with self.subTest('GroupedExampleSet_to_string'):
      self.assertEqual(expected_string, grouped.to_string())
    with self.subTest('GroupedExampleSet_str'):
      self.assertEqual(expected_string, str(grouped))
    with self.subTest('GroupedExampleSet_repr_subsumes_str'):
      self.assertContainsSubsequence(grouped_repr, expected_string)
    with self.subTest('GroupedExampleSet_repr_includes_example_metadata'):
      # Unlike an ExampleSet's repr, a GroupedExampleSet's repr does not yet
      # contain the ExampleSetMetadata. It does include the ExampleMetadata.
      if has_example_metadata:
        self.assertContainsExactSubsequence(grouped_repr, 'ExampleMetadata')


if __name__ == '__main__':
  # Run the test cases in this module with absl's test runner.
  absltest.main()
