# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_spec_loader
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.tools.metrics import consistency_metric


def _create_example(request,
                    reply,
                    qualifier,
                    context=None,
                    production_string=None):
  """Builds a `cl.Example` fixture from the given fields.

  Args:
    request: Request of the example.
    reply: Reply of the example.
    qualifier: Qualifier of the example.
    context: Context of the example. When None, a fresh
      `cl.FrozenExampleSet()` is used.
    production_string: String that is parsed with
      `nltk_utils.production_from_production_string` and stored as the
      production in the example's metadata. When None, the metadata's
      production is None.

  Returns:
    The newly created `cl.Example`.
  """
  production = (
      None if production_string is None else
      nltk_utils.production_from_production_string(production_string))
  metadata = cl.ExampleMetadata(production=production)
  return cl.Example(
      context=cl.FrozenExampleSet() if context is None else context,
      request=request,
      reply=reply,
      qualifier=qualifier,
      metadata=metadata)


def _get_3_test_contexts():
  """Returns a tuple of three contexts, each holding one distinct example."""
  # One (request, reply) pair per context; the production string is derived
  # from the pair so that it always matches the example it belongs to.
  request_reply_pairs = (('a', '1'), ('b', '2'), ('c', '3'))
  return tuple(
      cl.FrozenExampleSet.from_examples([
          _create_example(
              request=request,
              reply=reply,
              production_string=f"U[sem='{reply}'] -> '{request}'",
              qualifier=cl.Qualifier.M)
      ]) for request, reply in request_reply_pairs)


def _get_5_semi_consistent_examples(context):
  """Returns 5 examples of which one is implied and one is contradicted.

  Args:
    context: Context shared by all of the returned examples.
  """
  # Each entry is (request, reply, production_string).
  example_specs = (
      ('[walk] = WALK', cl.RuleReply.TRUE, "U[sem='WALK'] -> 'walk'"),
      ('[run] = RUN', cl.RuleReply.TRUE, "U[sem='RUN'] -> 'run'"),
      ('[x twice] = [x] [x]', cl.RuleReply.TRUE,
       "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"),
      # This example is implied by the rule examples above.
      ('walk twice', 'WALK WALK',
       "S[sem=('WALK'+'WALK')] -> 'walk' 'twice'"),
      # This example is contradicted by the rule examples above.
      ('run twice', 'RUN RUN RUN',
       "S[sem=('RUN'+'RUN'+'RUN')] -> 'run' 'twice'"),
  )
  return [
      _create_example(
          context=context,
          request=request,
          reply=reply,
          production_string=production_string,
          qualifier=cl.Qualifier.M)
      for request, reply, production_string in example_specs
  ]


def _get_5_consistent_examples(context):
  """Returns 5 mutually consistent examples.

  Args:
    context: Context shared by all of the returned examples.
  """
  # Each entry is (request, reply, production_string).
  example_specs = (
      ('[walk] = WALK', cl.RuleReply.TRUE, "U[sem='WALK'] -> 'walk'"),
      ('[run] = RUN', cl.RuleReply.TRUE, "U[sem='RUN'] -> 'run'"),
      ('[x twice] = [x] [x]', cl.RuleReply.TRUE,
       "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"),
      # The two examples below are both implied by the rule examples above.
      ('walk twice', 'WALK WALK',
       "S[sem=('WALK'+'WALK')] -> 'walk' 'twice'"),
      ('run twice', 'RUN RUN',
       "S[sem=('RUN'+'RUN')] -> 'run' 'twice'"),
  )
  return [
      _create_example(
          context=context,
          request=request,
          reply=reply,
          production_string=production_string,
          qualifier=cl.Qualifier.M)
      for request, reply, production_string in example_specs
  ]


def _get_5_inconsistent_examples(context):
  """Returns 5 inconsistent examples none of which is implied.

  Args:
    context: Context shared by all of the returned examples.
  """
  # Each entry is (request, reply, production_string).
  example_specs = (
      ('[walk] = WALK', cl.RuleReply.TRUE, "U[sem='WALK'] -> 'walk'"),
      ('[run] = RUN', cl.RuleReply.TRUE, "U[sem='RUN'] -> 'run'"),
      ('[x twice] = [x] [x]', cl.RuleReply.TRUE,
       "S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'"),
      # The two examples below are both contradicted by the rule examples
      # above.
      ('walk twice', 'WALK WALK WALK',
       "S[sem=('WALK'+'WALK'+'WALK')] -> 'walk' 'twice'"),
      ('run twice', 'RUN RUN RUN',
       "S[sem=('RUN'+'RUN'+'RUN')] -> 'run' 'twice'"),
  )
  return [
      _create_example(
          context=context,
          request=request,
          reply=reply,
          production_string=production_string,
          qualifier=cl.Qualifier.M)
      for request, reply, production_string in example_specs
  ]


class CandidateImplicationTest(parameterized.TestCase):
  """Tests for the qualifier-based checks of CandidateImplication."""

  def _build_candidate_implication(self, source_qualifiers, example_qualifier):
    """Creates a CandidateImplication from the given qualifiers.

    Args:
      source_qualifiers: Qualifiers of the source examples, one example is
        created per qualifier.
      example_qualifier: Qualifier of the candidate example itself.

    Returns:
      A CandidateImplication whose source examples carry `source_qualifiers`
      and whose example carries `example_qualifier`.
    """
    sources = frozenset(
        _create_example(
            request=f'request_{index}',
            reply=f'reply_{index}',
            qualifier=source_qualifier)
        for index, source_qualifier in enumerate(source_qualifiers))
    target = _create_example(
        request='example_request',
        reply='example_reply',
        qualifier=example_qualifier)
    return consistency_metric.CandidateImplication(
        example=target, source_examples=sources)

  @parameterized.named_parameters(
      ('defeasible_set_defeasible_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.M], cl.Qualifier.D, False),
      ('all_defeasible_set_defeasible_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.D], cl.Qualifier.D, False),
      ('monotonic_set_defeasible_example',
       [cl.Qualifier.M, cl.Qualifier.M, cl.Qualifier.M], cl.Qualifier.D, True),
      ('all_defeasible_set_monotonic_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.D], cl.Qualifier.M, False),
      ('monotonic_set_monotonic_example',
       [cl.Qualifier.M, cl.Qualifier.M, cl.Qualifier.M], cl.Qualifier.M, False))
  def test_is_qualifier_contradiction(self, source_qualifiers,
                                      example_qualifier, expected_value):
    candidate = self._build_candidate_implication(source_qualifiers,
                                                  example_qualifier)
    self.assertEqual(candidate.is_qualifier_contradiction(), expected_value)

  @parameterized.named_parameters(
      ('defeasible_set_defeasible_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.M], cl.Qualifier.D, True),
      ('all_defeasible_set_defeasible_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.D], cl.Qualifier.D, True),
      ('monotonic_set_defeasible_example',
       [cl.Qualifier.M, cl.Qualifier.M, cl.Qualifier.M], cl.Qualifier.D, False),
      ('all_defeasible_set_monotonic_example',
       [cl.Qualifier.D, cl.Qualifier.D, cl.Qualifier.D], cl.Qualifier.M, False),
      ('monotonic_set_monotonic_example',
       [cl.Qualifier.M, cl.Qualifier.M, cl.Qualifier.M], cl.Qualifier.M, True))
  def test_is_implication(self, source_qualifiers, example_qualifier,
                          expected_value):
    candidate = self._build_candidate_implication(source_qualifiers,
                                                  example_qualifier)
    self.assertEqual(candidate.is_implication(), expected_value)


class ConsistencyAggregationTest(parameterized.TestCase):
  """Tests for ExampleGroupConsistency and ExampleSetConsistency."""

  @parameterized.named_parameters(
      ('no_implications_or_contradictions', 0, 0, -1),
      ('no_contradictions', 1, 0, 1),
      ('no_implications', 0, 1, 0),
      ('more_implications_than_contradictions', 9, 1, 0.9),
      ('more_contradictions_than_implications', 1, 9, 0.1))
  def test_example_group_consistency(self, implications, contradictions,
                                     expected_consistency):
    # Checks the consistency reported by a single group built from the given
    # implication and contradiction counts.
    group_consistency = consistency_metric.ExampleGroupConsistency(
        context_id='md5_hash',
        implications=implications,
        contradictions=contradictions)
    # The consistency is a computed ratio, so it is compared with a tolerance
    # instead of exact float equality (as done for the example set below).
    self.assertAlmostEqual(group_consistency.consistency(),
                           expected_consistency)

  @parameterized.named_parameters(
      ('general_test', [(0, 0), (0, 1), (89, 9), (1, 0)], 0.9, 90, 10),
      ('no_implications_or_contradictions', [(0, 0), (0, 0)], -1, 0, 0),
      ('no_implications', [(0, 5), (0, 5)], 0, 0, 10),
      ('no_contradictions', [(5, 0), (5, 0)], 1, 10, 0))
  def test_example_set_consistency(self, groups_consistencies,
                                   expected_consistency, expected_implications,
                                   expected_contradictions):
    # `groups_consistencies` holds one (implications, contradictions) pair per
    # example group; the expected totals are the sums over all of the groups.
    example_groups_consistencies = [
        consistency_metric.ExampleGroupConsistency(
            context_id=f'md5_hash_{i}',
            implications=implications,
            contradictions=contradictions)
        for i, (implications, contradictions) in enumerate(
            groups_consistencies)
    ]
    example_set_consistency = consistency_metric.ExampleSetConsistency(
        example_groups_consistencies)
    with self.subTest('must_calculate_consistency_correctly'):
      self.assertAlmostEqual(example_set_consistency.consistency(),
                             expected_consistency)
    # The totals are integer counts, so they must match exactly.
    with self.subTest('must_sum_implications_correctly'):
      self.assertEqual(example_set_consistency.implications(),
                       expected_implications)
    with self.subTest('must_sum_contradictions_correctly'):
      self.assertEqual(example_set_consistency.contradictions(),
                       expected_contradictions)


class ConsistencyUtilsTest(parameterized.TestCase):
  """Tests for the helper functions of the consistency_metric module."""

  @parameterized.named_parameters(
      ('single_sem_feature', "U[sem='A'] -> 'a'", False, 'B',
       "U[sem='B'] -> 'a'"),
      ('empty_prediction', "U[sem='A'] -> 'a'", False, '',
       "U[sem=''] -> 'a'"),
      ('single_sem_feature_parenthesis', "U[sem=('A')] -> 'a'", False, 'B',
       "U[sem='B'] -> 'a'"),
      ('multi_sem_features', "U[sem=('A'+'B')] -> 'a'", False, 'B',
       "U[sem='B'] -> 'a'"),
      ('multi_reply_sem_features', "U[sem='A'] -> 'a'", False, 'B A',
       "U[sem=('B'+'A')] -> 'a'"),
      ('multi_sem_and_reply_sem_features', "U[sem=('A'+'D')] -> 'a'", False,
       'B D', "U[sem=('B'+'D')] -> 'a'"),
      ('rule_request', "U[sem='A'] -> 'a'", True, cl.RuleReply.TRUE,
       "U[sem='A'] -> 'a'"),
      ('contain_variables', "V[sem=('A'+?x1)] -> 'a' U[sem=?x1]", False,
       '?x1 A ?x1', "V[sem=(?x1, A, ?x1)] -> 'a' U[sem=?x1]"),
      ('rule_request_contains_variables', "V[sem=('A'+?x1)] -> 'a' U[sem=?x1]",
       True, cl.RuleReply.TRUE, "V[sem=('A'+?x1)] -> 'a' U[sem=?x1]"),
  )
  def test_get_production_from_example_and_prediction(
      self, production_string, is_rule_request,
      prediction_reply, expected_production_string):
    # Rule requests are created with the TRUE rule reply, all other requests
    # with an empty reply.
    original_reply = cl.RuleReply.TRUE if is_rule_request else ''
    source_example = _create_example(
        request='',
        reply=original_reply,
        qualifier=cl.Qualifier.M,
        production_string=production_string)
    expected = nltk_utils.production_from_production_string(
        expected_production_string)
    actual = consistency_metric._get_production_from_example_and_prediction(
        source_example, prediction_reply)
    self.assertEqual(actual, expected)

  @parameterized.named_parameters(
      ('scan_finite_nye_standardized', 'base', {
          'C[sem=?x1] -> S[sem=?x1]',
          'S[sem=?x1] -> V[sem=?x1]',
          'V[sem=?x1] -> D[sem=?x1]',
          'D[sem=?x1] -> U[sem=?x1]',
      }),
      ('scan_extended', 'extended', {
          'E[sem=?x1] -> C[sem=?x1]',
          'C[sem=?x1] -> S[sem=?x1]',
          'S[sem=?x1] -> V[sem=?x1]',
          'V[sem=?x1] -> D[sem=?x1]',
          'D[sem=?x1] -> U[sem=?x1]',
      }))
  def test_get_base_inference_engine(self, dataset_spec_id,
                                     expected_productions):
    # The expected productions are given as strings and parsed before they are
    # compared to the source productions of the engine.
    parsed_expected = {
        nltk_utils.production_from_production_string(production_string)
        for production_string in expected_productions
    }
    spec = dataset_spec_loader.load_dataset_spec(dataset_spec_id)
    engine = consistency_metric._get_base_inference_engine(spec)
    self.assertCountEqual(parsed_expected, engine.source_productions)

  def test_merge_examples_and_predictions(self):
    context_1, context_2, context_3 = _get_3_test_contexts()
    # Each entry is (context, request, reply, production_string).
    example_specs = (
        (context_1, 'a', 'A', "U[sem='A'] -> 'a'"),
        (context_2, 'b', 'B', "U[sem='B'] -> 'b'"),
        (context_3, '[c] = C', cl.RuleReply.TRUE, "U[sem='C'] -> 'c'"),
    )
    dataset = cl.ExampleSet()
    for context, request, reply, production_string in example_specs:
      dataset.add_example(
          _create_example(
              context=context,
              request=request,
              reply=reply,
              qualifier=cl.Qualifier.M,
              production_string=production_string))

    # One (qualifier, reply) prediction per dataset example, in dataset order.
    predicted_outputs = (
        (cl.Qualifier.D, 'PA'),
        (cl.Qualifier.D, 'PB'),
        (cl.Qualifier.D, cl.RuleReply.TRUE),
    )
    predictions = {
        dataset[index].get_md5_hash(): predicted
        for index, predicted in enumerate(predicted_outputs)
    }
    prediction_set = consistency_metric._merge_examples_and_predictions(
        dataset, predictions)
    self.assertLen(prediction_set, 3)

    with self.subTest('must_change_the_reply_to_the_predicted_reply'):
      for index, merged_example in enumerate(prediction_set):
        expected_reply = predictions[dataset[index].get_md5_hash()][1]
        self.assertEqual(merged_example.reply, expected_reply)

    with self.subTest('must_change_the_qualifier_to_the_predicted_qualifier'):
      for index, merged_example in enumerate(prediction_set):
        expected_qualifier = predictions[dataset[index].get_md5_hash()][0]
        self.assertEqual(merged_example.qualifier, expected_qualifier)

    # The rule example keeps its production while the other two productions
    # are expected to carry the predicted reply.
    expected_productions = list(
        map(nltk_utils.production_from_production_string, (
            "U[sem='PA'] -> 'a'",
            "U[sem='PB'] -> 'b'",
            "U[sem='C'] -> 'c'",
        )))
    actual_productions = [
        merged_example.metadata.production for merged_example in prediction_set
    ]
    with self.subTest('must_update_the_productions_based_on_the_predictions'):
      self.assertListEqual(expected_productions, actual_productions)

  @parameterized.named_parameters(
      ('known_step_validation', 10, True, 'eval_consistency10.json'),
      ('unknown_step_validation', -1, True, 'eval_consistency.json'),
      ('known_step_test', 10, False, 'test_consistency10.json'),
      ('unknown_step_test', -1, False, 'test_consistency.json'))
  def test_get_consistency_file_name(self, prediction_step, is_validation,
                                     expected_name):
    actual_name = consistency_metric.get_consistency_file_name(
        prediction_step=prediction_step, is_validation=is_validation)
    self.assertEqual(actual_name, expected_name)


class ConsistencyMetricTest(parameterized.TestCase):
  """Tests for the consistency computation over examples and predictions.

  Every test is run twice: once counting implied / contradicted sets of
  examples and once counting implied / contradicted individual examples
  (`example_level_consistency=True`).
  """

  @parameterized.named_parameters(
      ('sets_consistency', False, 5, 8, 5 / (5 + 8)),
      ('examples_consistency', True, 4, 10, 4 / (4 + 10)))
  def test_consistency_for_example_group(self, example_level_consistency,
                                         expected_implications,
                                         expected_contradictions,
                                         expected_consistency):
    # a: <C, [walk] = WALK, 1, M>: Impl(e), Cont(c, g), Cont(b, d, k),
    # b: <C, [jump] = JUMP, 1, M>: Cont(a, d, k), Cont(e, d, k), Cont(d, g, l)
    # c: <C, [x twice] = [x] [x], 1, M>: Cont(a, g)
    # d: <C, [x1 after x2] = [x2] [x1], 1, M>: Cont(a, b, k), Cont(e, b, k)
    #  Cont(b, g, l)
    # e: <C, walk, WALK, M> : Impl(a), Cont(c, g), Cont(b, d, k),
    # f: <C, eat twice, EAT EAT, M>: -
    # g: <C, walk twice, WALK WALK WALK WALK, M>: Cont(a, c), Cont(b, d, l)
    # h: <C, jump twice, JUMP JUMP, D>: QCont(b, c)
    # i: <C, walk after walk, WALK WALK, M>: Impl(a, d), Impl(e, d)
    # j: <C, jump after walk, WALK JUMP, D>: QCont(a, b, d), QCont(e, b, d)
    # k: <C, walk after jump, WALK JUMP, M>: Cont(a, b, d), Cont(e, b, d)
    # l: <C, walk twice after jump, JUMP WALK WALK, M>: Impl(a, b, c, d),
    #  Impl(e, b, c, d), Cont(b, d, g)
    # In these examples the implied sets are:
    # {a, e}, {a, d, i}, {e, d, i}, {a, b, c, d, l}, {e, b, c, d, l} = 5
    # While the contradicted sets are:
    # {a, c, g}, {e, c, g}, {a, b, d, k}, {e, b, d, k}, {b, d, g, l}, {b, c, h},
    # {a, b, d, j}, {e, b, d, j} = 8
    # At the example level, the annotations above list an implication for
    # a, e, i and l (= 4) and a contradiction for every example except f and
    # i (= 10).
    examples = [
        _create_example(
            request='[walk] = WALK',
            reply=cl.RuleReply.TRUE,
            production_string="U[sem='WALK'] -> 'walk'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='[jump] = JUMP',
            reply=cl.RuleReply.TRUE,
            production_string="U[sem='JUMP'] -> 'jump'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='[x twice] = [x] [x]',
            reply=cl.RuleReply.TRUE,
            production_string="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='[x1 after x2] = [x2] [x1]',
            reply=cl.RuleReply.TRUE,
            production_string=(
                "C[sem=(?x2+?x1)] -> S[sem=?x1] 'after' S[sem=?x2]"),
            qualifier=cl.Qualifier.M),
        _create_example(
            request='walk',
            reply='WALK',
            production_string="U[sem='WALK'] -> 'walk'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='eat twice',
            reply='EAT EAT',
            production_string="S[sem=('EAT'+'EAT')] -> 'eat' 'twice'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='walk twice',
            reply='WALK WALK WALK WALK',
            production_string=(
                "S[sem=('WALK'+'WALK'+'WALK'+'WALK')] -> 'walk' 'twice'"),
            qualifier=cl.Qualifier.M),
        _create_example(
            request='jump twice',
            reply='JUMP JUMP',
            production_string="S[sem=('JUMP'+'JUMP')] -> 'jump' 'twice'",
            qualifier=cl.Qualifier.D),
        _create_example(
            request='walk after walk',
            reply='WALK WALK',
            production_string=(
                "C[sem=('WALK'+'WALK')] -> 'walk' 'after' 'walk'"),
            qualifier=cl.Qualifier.M),
        _create_example(
            request='jump after walk',
            reply='WALK JUMP',
            production_string=(
                "C[sem=('WALK'+'JUMP')] -> 'jump' 'after' 'walk'"),
            qualifier=cl.Qualifier.D),
        _create_example(
            request='walk after jump',
            reply='WALK JUMP',
            production_string=(
                "C[sem=('WALK'+'JUMP')] -> 'walk' 'after' 'jump'"),
            qualifier=cl.Qualifier.M),
        _create_example(
            request='walk twice after jump',
            reply='JUMP WALK WALK',
            production_string=(
                "C[sem=('JUMP'+'WALK'+'WALK')] -> 'walk' 'twice' 'after' 'jump'"
            ),
            qualifier=cl.Qualifier.M)
    ]
    inference_engine = consistency_metric._get_base_inference_engine(
        dataset_spec_loader.load_dataset_spec('base'))

    consistency = consistency_metric.consistency_for_example_group(
        examples,
        context_id='md5_hash',
        inference_engine=inference_engine,
        example_level_consistency=example_level_consistency)

    with self.subTest('should_detect_all_implied_examples'):
      self.assertEqual(consistency.implications, expected_implications)

    with self.subTest('should_detect_all_contradicted_examples'):
      self.assertEqual(consistency.contradictions, expected_contradictions)

    with self.subTest('should_calculate_the_consistency_correctly'):
      self.assertAlmostEqual(consistency.consistency(), expected_consistency)

  @parameterized.named_parameters(
      ('sets_consistency', False, 1, 2, 1 / (1 + 2)),
      ('examples_consistency', True, 2, 4, 2 / (2 + 4)))
  def test_consistency_for_example_group_with_multiple_inconsistency_sources(
      self, example_level_consistency, expected_implications,
      expected_contradictions, expected_consistency):
    # a: <C, [x twice] = [x] [x], 1, M>: Cont(b, d), Cont(c, d)
    # b: <C, [jump] = JUMP, 1, M>: Impl(c), Cont(a, d)
    # c: <C, jump, JUMP, M>: Impl(b), Cont(a, d)
    # d: <C, jump twice, JUMP JUMP JUMP, M>: Cont(a, b), Cont(a, c)
    # In these examples the implied sets are:
    # {b, c} = 1
    # While the contradicted sets are:
    # {a, b, d}, {a, c, d} = 2
    # At the example level, the annotations above list an implication for b
    # and c (= 2) and a contradiction for all four examples (= 4).
    examples = [
        _create_example(
            request='[x twice] = [x] [x]',
            reply=cl.RuleReply.TRUE,
            production_string="S[sem=(?x1+?x1)] -> V[sem=?x1] 'twice'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='[jump] = JUMP',
            reply=cl.RuleReply.TRUE,
            production_string="U[sem='JUMP'] -> 'jump'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='jump',
            reply='JUMP',
            production_string="U[sem='JUMP'] -> 'jump'",
            qualifier=cl.Qualifier.M),
        _create_example(
            request='jump twice',
            reply='JUMP JUMP JUMP',
            production_string="S[sem=('JUMP'+'JUMP'+'JUMP')] -> 'jump' 'twice'",
            qualifier=cl.Qualifier.M),
    ]
    inference_engine = consistency_metric._get_base_inference_engine(
        dataset_spec_loader.load_dataset_spec('base'))

    consistency = consistency_metric.consistency_for_example_group(
        examples, context_id='md5_hash', inference_engine=inference_engine,
        example_level_consistency=example_level_consistency)

    with self.subTest('should_detect_all_implied_examples'):
      self.assertEqual(consistency.implications, expected_implications)

    with self.subTest('should_detect_all_contradicted_examples'):
      self.assertEqual(consistency.contradictions, expected_contradictions)

    with self.subTest('should_calculate_the_consistency_correctly'):
      self.assertAlmostEqual(consistency.consistency(), expected_consistency)

  @parameterized.named_parameters(
      ('sets_consistency', False, 3, 3, 3 / (3 + 3)),
      ('examples_consistency', True, 3, 8, 3 / (3 + 8)))
  def test_consistency_for_example_groups(self, example_level_consistency,
                                          expected_implications,
                                          expected_contradictions,
                                          expected_consistency):
    context_1, context_2, context_3 = _get_3_test_contexts()
    # Examples level: Implications: 1, Contradictions: 3
    # Sets level: Implications: 1, Contradictions: 1
    semi_consistent_set = _get_5_semi_consistent_examples(context_1)
    # Examples level: Implications: 2, Contradictions: 0
    # Sets level: Implications: 2, Contradictions: 0
    consistent_set = _get_5_consistent_examples(context_2)
    # Examples level: Implications: 0, Contradictions: 5
    # Sets level: Implications: 0, Contradictions: 2
    no_implication_set = _get_5_inconsistent_examples(context_3)

    example_groups = [consistent_set, semi_consistent_set, no_implication_set]
    example_set_consistency = (
        consistency_metric.consistency_for_example_groups(
            example_groups,
            dataset_spec=dataset_spec_loader.load_dataset_spec('base'),
            example_level_consistency=example_level_consistency))

    with self.subTest('should_detect_all_implied_examples'):
      self.assertEqual(example_set_consistency.implications(),
                       expected_implications)

    with self.subTest('should_detect_all_contradicted_examples'):
      self.assertEqual(example_set_consistency.contradictions(),
                       expected_contradictions)

    # The consistency is a computed ratio, so it is compared with a tolerance
    # rather than with exact float equality.
    with self.subTest('should_calculate_consistency_correctly'):
      self.assertAlmostEqual(example_set_consistency.consistency(),
                             expected_consistency)

  @parameterized.named_parameters(
      ('sets_consistency', False, 3, 3, 3 / (3 + 3)),
      ('examples_consistency', True, 3, 8, 3 / (3 + 8)))
  def test_compute_consistency_for_model_predictions(self,
                                                     example_level_consistency,
                                                     expected_implications,
                                                     expected_contradictions,
                                                     expected_consistency):
    # To test this method we'll first create a consistent example set using the
    # method _get_5_consistent_examples() 3 times each time with different
    # context.
    context_1, context_2, context_3 = _get_3_test_contexts()
    dataset_examples = [
        *_get_5_consistent_examples(context_1),
        *_get_5_consistent_examples(context_2),
        *_get_5_consistent_examples(context_3)
    ]
    # Convert them to a cl.ExampleSet object.
    dataset = cl.ExampleSet()
    for example in dataset_examples:
      dataset.add_example(example)
    # Then convert the cl.ExampleSet object to a cl.GroupedExampleSet.
    grouped_example_set = cl.GroupedExampleSet.from_example_set(dataset)
    # The predictions will consist of 3 example groups as well constructed
    # from the methods _get_5_consistent_examples,
    # _get_5_semi_consistent_examples and _get_5_inconsistent_examples in that
    # order. The idea is that the original dataset is consistent while the
    # predictions are consistent in the first group, contain some implications
    # and contradictions in the second group and just contradictions in the
    # third.
    prediction_examples = [
        *_get_5_consistent_examples(context_1),
        *_get_5_semi_consistent_examples(context_2),
        *_get_5_inconsistent_examples(context_3)
    ]

    # Pair every dataset example with the prediction at the same position and
    # key the predicted (qualifier, reply) by the dataset example's hash.
    predictions_by_md5_hash = {}
    for example, prediction_example in zip(dataset, prediction_examples):
      predictions_by_md5_hash[example.get_md5_hash()] = (
          prediction_example.qualifier, prediction_example.reply)

    example_set_consistency = (
        consistency_metric.compute_consistency_for_model_predictions(
            grouped_example_set,
            predictions_by_md5_hash,
            dataset_spec=dataset_spec_loader.load_dataset_spec('base'),
            example_level_consistency=example_level_consistency))

    with self.subTest('must_detect_all_implications'):
      self.assertEqual(example_set_consistency.implications(),
                       expected_implications)

    with self.subTest('must_detect_all_contradictions'):
      self.assertEqual(example_set_consistency.contradictions(),
                       expected_contradictions)

    # The consistency is a computed ratio, so it is compared with a tolerance
    # rather than with exact float equality.
    with self.subTest('must_compute_the_consistency_correctly'):
      self.assertAlmostEqual(example_set_consistency.consistency(),
                             expected_consistency)


# Run the tests of this module with absl's test runner when the file is
# executed as a script.
if __name__ == '__main__':
  absltest.main()
