# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import dataclasses
import traceback
from typing import Any

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import production_trees
from conceptual_learning.cscan import test_utils


def _assert_counter_equal(test_case, actual,
                          expected):
  """Helper function to compare counter-like objects with detailed message."""
  test_case.assertEqual(actual, expected, f'Actual counters = {actual}')


class DistributionSummaryStatsTest(absltest.TestCase):
  """Tests for `outputs.DistributionSummaryStats`."""

  def test_initial_values(self):
    empty_stats = outputs.DistributionSummaryStats()

    with self.subTest('mean_should_initially_be_zero'):
      self.assertEqual(0.0, empty_stats.mean)

    with self.subTest('stddev_should_initially_be_zero'):
      self.assertEqual(0.0, empty_stats.stddev)

  def test_update_with_single_value(self):
    single_value_stats = outputs.DistributionSummaryStats()
    single_value_stats.update_with_value(1.0)

    with self.subTest('mean_should_equal_the_value_added'):
      self.assertEqual(1.0, single_value_stats.mean)

    with self.subTest(
        'stddev_should_still_be_zero_until_there_are_two_or_more_values'):
      self.assertEqual(0.0, single_value_stats.stddev)

  def test_update_with_multiple_values(self):
    stats = outputs.DistributionSummaryStats()
    for value in (1.0, 3.0, 2.0, 0.0):
      stats.update_with_value(value)

    # stddev is a float that needs an approximate comparison, so it is zeroed
    # out here and every other field is compared exactly.
    want_except_stddev = outputs.DistributionSummaryStats(
        min=0.0, max=3.0, sum=6.0, sum_of_squares=14.0, count=4.0, mean=1.5)
    got_except_stddev = dataclasses.replace(stats, stddev=0.0)

    with self.subTest('correct_except_stddev'):
      self.assertEqual(want_except_stddev, got_except_stddev)

    with self.subTest('correct_stddev'):
      # Value calculated by STDEV in Google Sheets, i.e. the sample standard
      # deviation sqrt((14 - 6**2 / 4) / 3) ~= 1.29099.
      self.assertAlmostEqual(1.2910, stats.stddev, 4)

  def test_add(self):
    lhs = outputs.DistributionSummaryStats(
        min=0.0, max=3.0, sum=6.0, sum_of_squares=14.0, count=4.0, mean=1.5,
        stddev=1.0)
    rhs = outputs.DistributionSummaryStats(
        min=1.0, max=4.0, sum=10.0, sum_of_squares=30.0, count=4.0, mean=2.5,
        stddev=1.0)

    # stddev is checked separately below with assertAlmostEqual.
    want_except_stddev = outputs.DistributionSummaryStats(
        min=0.0, max=4.0, sum=16.0, sum_of_squares=44.0, count=8.0, mean=2.0)

    combined = lhs + rhs
    got_except_stddev = dataclasses.replace(combined, stddev=0.0)

    with self.subTest('correct_except_stddev'):
      self.assertEqual(want_except_stddev, got_except_stddev)

    with self.subTest('correct_stddev'):
      # The combined stddev (and mean) are derived from the combined sum,
      # sum_of_squares and count rather than from the operands' stddev fields:
      # sqrt((44 - 16**2 / 8) / 7) ~= 1.30931.
      self.assertAlmostEqual(1.3093, combined.stddev, 4)


class ExampleAttemptCountersTest(absltest.TestCase):
  """Tests for `outputs.ExampleAttemptCounters`."""

  def test_total_should_be_sum_of_the_individual_buckets(self):
    # Each bucket holds a distinct power of two, so every subset of buckets
    # sums to a unique value; a dropped or double-counted bucket shows up as a
    # distinct wrong total.
    attempts_by_bucket = dict(
        already_in_context=1,
        ambiguous=2,
        duplicate=4,
        max_derivation_level_reached=8,
        missing_target_rule=16,
        unparseable=32,
        unable_to_create_negative_example=64,
        valid=128,
    )
    counters = outputs.ExampleAttemptCounters(**attempts_by_bucket)
    self.assertEqual(sum(attempts_by_bucket.values()), counters.get_total())

  def test_valid_fraction_should_be_valid_attempts_divided_by_total(self):
    # 10 valid attempts out of 5 + 5 + 10 = 20 total.
    half_valid = outputs.ExampleAttemptCounters(
        ambiguous=5, unparseable=5, valid=10)
    self.assertEqual(0.5, half_valid.get_valid_fraction())

  def test_valid_fraction_should_be_zero_if_no_attempts_were_made(self):
    no_attempts = outputs.ExampleAttemptCounters()
    self.assertEqual(0.0, no_attempts.get_valid_fraction())


class ExampleCountersTest(parameterized.TestCase):
  """Tests for `outputs.ExampleCounters`."""

  def test_total_should_be_zero_when_counters_are_empty(self):
    counters = outputs.ExampleCounters()
    self.assertEqual(0, counters.get_total())

  def test_total_should_be_sum_of_the_individual_buckets(self):
    counters = outputs.ExampleCounters()
    counters.by_request_type[cl.RequestType.NON_RULE] += 1
    counters.by_request_type[cl.RequestType.RULE] += 2
    expected_total = 1 + 2
    self.assertEqual(expected_total, counters.get_total())

  def test_update_with_example_and_context(self):
    """Counts a single example against a context with one rule of each kind."""
    mutable_context = cl.ExampleSet()
    omitted_rule = 'omitted rule'
    explicit_rule = 'explicit rule'
    unreliable_rule = 'unreliable rule'
    distractor_rule = 'distractor rule'
    hidden_rule = 'hidden rule'
    mutable_context.add_omitted_rule(omitted_rule)
    mutable_context.add_explicit_rule(
        explicit_rule,
        test_utils.create_example_from_explicit_rule(explicit_rule))
    mutable_context.mark_rule_as_unreliable(unreliable_rule)
    mutable_context.add_unreliable_rule(
        unreliable_rule,
        illustrative_examples=[
            cl.Example(request='s', reply='t'),
            cl.Example(request='u', reply='v')
        ])
    mutable_context.add_hidden_rule(
        hidden_rule,
        illustrative_examples=[
            cl.Example(request='a', reply='b'),
            cl.Example(request='c', reply='d')
        ])
    mutable_context.metadata.rule_reply_by_hidden_rule = {
        hidden_rule: cl.RuleReply.TRUE
    }
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    example = cl.Example(
        request='e',
        reply='f',
        qualifier=cl.Qualifier.M,
        metadata=cl.ExampleMetadata(
            rules={omitted_rule, explicit_rule, distractor_rule, hidden_rule},
            derivation_level=3,
            num_variables=4,
            applied_edits=['some_edit'],
            distractor_rules_by_unreliable_rule={
                unreliable_rule: [distractor_rule]
            },
            input_length_standard=5,
            output_length_standard=6,
            input_length_compact=7,
            output_length_compact=8))
    counters = outputs.ExampleCounters()

    counters.update_with_example_and_context(example, context)
    # Each length distribution has seen exactly one value, so min == max ==
    # sum == mean == that value, sum_of_squares is its square, and stddev is 0.
    expected = outputs.ExampleCounters(
        by_request_type={cl.RequestType.NON_RULE: 1},
        by_example_type={cl.ExampleType.NONRULE_KNOWN_M: 1},
        by_qualifier={cl.Qualifier.M: 1},
        by_knownness={'KNOWN_MONOTONIC': 1},
        by_derivation_level={3: 1},
        by_num_variables={4: 1},
        by_derivation_level_and_num_variables={'(3, 4)': 1},
        by_num_omitted_rules={1: 1},
        by_num_explicit_rules={1: 1},
        by_num_unreliable_rules={1: 1},
        by_num_hidden_rules={1: 1},
        # The example's metadata lists four rules.
        by_num_rules={4: 1},
        by_applied_edits={'some_edit': 1},
        by_triviality={cl.Triviality.NON_TRIVIAL: 1},
        input_length_stats_standard=outputs.DistributionSummaryStats(
            min=5.0,
            max=5.0,
            sum=5.0,
            sum_of_squares=25.0,
            count=1.0,
            mean=5.0,
            stddev=0.0),
        output_length_stats_standard=outputs.DistributionSummaryStats(
            min=6.0,
            max=6.0,
            sum=6.0,
            sum_of_squares=36.0,
            count=1.0,
            mean=6.0,
            stddev=0.0),
        input_length_stats_compact=outputs.DistributionSummaryStats(
            min=7.0,
            max=7.0,
            sum=7.0,
            sum_of_squares=49.0,
            count=1.0,
            mean=7.0,
            stddev=0.0),
        output_length_stats_compact=outputs.DistributionSummaryStats(
            min=8.0,
            max=8.0,
            sum=8.0,
            sum_of_squares=64.0,
            count=1.0,
            mean=8.0,
            stddev=0.0))
    _assert_counter_equal(self, counters, expected)

  def test_get_fraction_by_example_type_by_request_type(self):
    """Each example-type count is normalized by its request type's total."""
    counters = outputs.ExampleCounters(
        by_example_type={
            cl.ExampleType.NONRULE_KNOWN_D: 1,
            cl.ExampleType.NONRULE_KNOWN_M: 2,
            cl.ExampleType.NONRULE_UNKNOWN_D: 3,
            cl.ExampleType.RULE_KNOWN_FALSE_D: 11,
            cl.ExampleType.RULE_KNOWN_FALSE_M: 12,
            cl.ExampleType.RULE_KNOWN_TRUE_D: 13,
            cl.ExampleType.RULE_KNOWN_TRUE_M: 14,
            cl.ExampleType.RULE_UNKNOWN_D: 15
        },
        by_request_type={
            cl.RequestType.NON_RULE: 1 + 2 + 3,  # total: 6
            cl.RequestType.RULE: 11 + 12 + 13 + 14 + 15  # total: 65
        })
    fraction_by_example_type_by_request_type = (
        counters.get_fraction_by_example_type_by_request_type())
    expected = {
        cl.RequestType.NON_RULE: {
            cl.ExampleType.NONRULE_KNOWN_D: 1 / 6,
            cl.ExampleType.NONRULE_KNOWN_M: 2 / 6,
            cl.ExampleType.NONRULE_UNKNOWN_D: 3 / 6
        },
        cl.RequestType.RULE: {
            cl.ExampleType.RULE_KNOWN_FALSE_D: 11 / 65,
            cl.ExampleType.RULE_KNOWN_FALSE_M: 12 / 65,
            cl.ExampleType.RULE_KNOWN_TRUE_D: 13 / 65,
            cl.ExampleType.RULE_KNOWN_TRUE_M: 14 / 65,
            cl.ExampleType.RULE_UNKNOWN_D: 15 / 65
        }
    }

    for request_type, expected_fraction_by_example_type in expected.items():
      for example_type, expected_fraction in (
          expected_fraction_by_example_type.items()):
        calculated_fraction = (
            fraction_by_example_type_by_request_type[request_type][example_type]
        )
        self.assertEqual(expected_fraction, calculated_fraction)

  def test_get_fraction_by_example_type_by_request_type_missing_request_type(
      self):
    """A request type with a zero total yields zero fractions, not an error."""
    counters = outputs.ExampleCounters(
        by_example_type={
            cl.ExampleType.NONRULE_KNOWN_D: 1,
            cl.ExampleType.NONRULE_KNOWN_M: 2,
            cl.ExampleType.NONRULE_UNKNOWN_D: 3,
        },
        by_request_type={
            cl.RequestType.NON_RULE: 1 + 2 + 3,  # total: 6
            cl.RequestType.RULE: 0
        })
    fraction_by_example_type_by_request_type = (
        counters.get_fraction_by_example_type_by_request_type())

    # There are no rule examples, so the fractions should be zero for those.
    # NOTE(review): if the RULE entry were an empty mapping, this loop would
    # pass vacuously.
    for fraction in (
        fraction_by_example_type_by_request_type[cl.RequestType.RULE].values()):
      # Expected value first, consistent with the rest of this file.
      self.assertEqual(0.0, fraction)


class ContextCountersTest(absltest.TestCase):
  """Tests for `outputs.ContextCounters`."""

  def _expected_after_one_context(self, **overrides):
    """Returns the ContextCounters expected after counting a single context.

    Every per-context breakdown defaults to "one context with a count of 0";
    keyword arguments in `overrides` replace individual breakdowns.
    """
    fields = dict(
        total=1,
        by_num_omitted_rules={0: 1},
        by_num_explicit_rules={0: 1},
        by_num_unreliable_rules={0: 1},
        by_num_hidden_rules={0: 1},
        by_num_hidden_true_rules={0: 1},
        by_num_hidden_unknown_rules={0: 1},
        by_num_rules={0: 1},
        by_num_examples={0: 1})
    fields.update(overrides)
    return outputs.ContextCounters(**fields)

  def test_update_with_context_should_count_omitted_rules(self):
    builder = cl.ExampleSet()
    builder.add_omitted_rule('some rule')
    frozen = cl.FrozenExampleSet.from_example_set(builder)

    actual = outputs.ContextCounters()
    actual.update_with_context(frozen)

    _assert_counter_equal(
        self, actual,
        self._expected_after_one_context(
            by_num_omitted_rules={1: 1},
            by_num_rules={1: 1},
            by_num_examples={0: 1}))

  def test_update_with_context_should_count_explicit_rules_and_examples(self):
    rule_text = 'some rule'
    builder = cl.ExampleSet()
    builder.add_explicit_rule(
        rule_text, test_utils.create_example_from_explicit_rule(rule_text))
    frozen = cl.FrozenExampleSet.from_example_set(builder)

    actual = outputs.ContextCounters()
    actual.update_with_context(frozen)

    _assert_counter_equal(
        self, actual,
        self._expected_after_one_context(
            by_num_explicit_rules={1: 1},
            by_num_rules={1: 1},
            by_num_examples={1: 1}))

  def test_update_with_context_should_count_unreliable_rules_and_examples(self):
    rule_text = 'some rule'
    builder = cl.ExampleSet()
    builder.mark_rule_as_unreliable(rule_text)
    builder.add_unreliable_rule(
        rule_text,
        illustrative_examples=[
            cl.Example(request='a', reply='b'),
            cl.Example(request='c', reply='d')
        ])
    frozen = cl.FrozenExampleSet.from_example_set(builder)

    actual = outputs.ContextCounters()
    actual.update_with_context(frozen)

    _assert_counter_equal(
        self, actual,
        self._expected_after_one_context(
            by_num_unreliable_rules={1: 1},
            by_num_rules={1: 1},
            by_num_examples={2: 1}))

  def test_update_with_context_should_count_hidden_rules_and_examples(self):
    rule_text = 'some rule'
    builder = cl.ExampleSet()
    builder.add_hidden_rule(
        rule_text,
        illustrative_examples=[
            cl.Example(request='a', reply='b'),
            cl.Example(request='c', reply='d')
        ])
    builder.metadata.rule_reply_by_hidden_rule = {
        rule_text: cl.RuleReply.UNKNOWN
    }
    frozen = cl.FrozenExampleSet.from_example_set(builder)

    actual = outputs.ContextCounters()
    actual.update_with_context(frozen)

    # The hidden rule's reply is UNKNOWN, so it is counted as hidden-unknown
    # and not as hidden-true.
    _assert_counter_equal(
        self, actual,
        self._expected_after_one_context(
            by_num_hidden_rules={1: 1},
            by_num_hidden_true_rules={0: 1},
            by_num_hidden_unknown_rules={1: 1},
            by_num_rules={1: 1},
            by_num_examples={2: 1}))

  def test_update_with_example_group_should_count_examples(self):
    top_level_examples = [
        # Known non-rule example.
        cl.Example(request='a1', reply='b1'),
        # Unknown non-rule example.
        cl.Example(
            request='a2',
            reply=cl.RuleReply.UNKNOWN,
            qualifier=cl.Qualifier.D,
            metadata=cl.ExampleMetadata(original_reply='b2')),
        # Positive rule example.
        cl.Example(request='a3', reply=cl.RuleReply.TRUE),
        # Negative rule example.
        cl.Example(request='a4', reply=cl.RuleReply.FALSE),
        # Unknown rule example.
        cl.Example(
            request='a5', reply=cl.RuleReply.UNKNOWN, qualifier=cl.Qualifier.D),
    ]
    example_group = cl.ExampleGroup()
    for top_level_example in top_level_examples:
      example_group.add_example(top_level_example)

    actual = outputs.ContextCounters()
    actual.update_with_example_group(example_group)

    expected = outputs.ContextCounters(
        by_num_unknown_rule_top_level_examples={1: 1},
        by_num_positive_rule_top_level_examples={1: 1},
        by_num_negative_rule_top_level_examples={1: 1},
        by_num_unknown_nonrule_top_level_examples={1: 1})
    _assert_counter_equal(self, actual, expected)


class RuleCountersTest(parameterized.TestCase):
  """Tests for `outputs.RuleBreakdownCounters` and `outputs.RuleCounters`."""

  def test_update_with_context_should_count_explicit_rules(self):
    mutable_context = cl.ExampleSet()
    rule = 'some rule'
    mutable_context.add_explicit_rule(
        rule, test_utils.create_example_from_explicit_rule(rule))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)

    counters = outputs.RuleBreakdownCounters()
    counters.update_with_context(context)

    expected_all = outputs.RuleCounters(
        total=1,
        by_num_context_examples={1: 1},
        by_num_reliable_context_examples={1: 1},
        by_min_reliable_derivation_level={0: 1})
    expected = outputs.RuleBreakdownCounters(
        all=expected_all, explicit=expected_all)

    _assert_counter_equal(self, counters, expected)

  @parameterized.named_parameters(('hidden_true', cl.RuleReply.TRUE),
                                  ('hidden_unknown', cl.RuleReply.UNKNOWN))
  def test_update_with_context_should_count_hidden_and_distractor_rules(
      self, rule_reply):
    mutable_context = cl.ExampleSet()
    rule = 'some rule'
    unreliable_rule = 'unreliable rule'
    distractor_rule = 'distractor rule'
    mutable_context.mark_rule_as_unreliable(unreliable_rule)
    mutable_context.add_hidden_rule(
        rule,
        illustrative_examples=[
            cl.Example(
                request='a',
                reply='b',
                metadata=cl.ExampleMetadata(rules={rule}, derivation_level=8)),
            cl.Example(
                request='c',
                reply='d',
                # We include a distractor rule here so as to verify that an
                # example involving a distractor rule is not counted as a
                # reliable context example for the hidden rule. Hence the
                # hidden rule has only 1 reliable context example, and its
                # minimum reliable derivation level is 8 rather than 2.
                metadata=cl.ExampleMetadata(
                    rules={rule, distractor_rule},
                    derivation_level=2,
                    distractor_rules_by_unreliable_rule=({
                        unreliable_rule: [distractor_rule]
                    })))
        ])
    mutable_context.metadata.rule_reply_by_hidden_rule = {rule: rule_reply}
    context = cl.FrozenExampleSet.from_example_set(mutable_context)

    counters = outputs.RuleBreakdownCounters()
    counters.update_with_context(context)

    # The 'all' counters aggregate the three per-category counters below:
    # the unreliable rule (0 examples), the distractor rule (1 example, none
    # reliable) and the hidden rule (2 examples, 1 reliable).
    expected_all = outputs.RuleCounters(
        total=3,
        by_num_context_examples={
            0: 1,
            1: 1,
            2: 1
        },
        by_num_reliable_context_examples={
            0: 2,
            1: 1
        },
        by_min_reliable_derivation_level={
            # Doesn't count rules that appeared in no reliable examples, as the
            # minimum reliable derivation level is not well defined in that
            # case.
            8: 1
        })
    expected_hidden = outputs.RuleCounters(
        total=1,
        by_num_context_examples={2: 1},
        by_num_reliable_context_examples={1: 1},
        by_min_reliable_derivation_level={8: 1})
    expected_unreliable = outputs.RuleCounters(
        total=1,
        by_num_context_examples={0: 1},
        by_num_reliable_context_examples={0: 1})
    expected_distractor = outputs.RuleCounters(
        total=1,
        by_num_context_examples={1: 1},
        by_num_reliable_context_examples={0: 1})

    expected = outputs.RuleBreakdownCounters(
        all=expected_all,
        unreliable=expected_unreliable,
        distractor=expected_distractor)
    if rule_reply == cl.RuleReply.TRUE:
      expected.hidden_true = expected_hidden
    else:
      expected.hidden_unknown = expected_hidden

    # We check the RuleCounters objects individually first so as to yield
    # smaller and more easily debuggable messages in case of failure.
    with self.subTest('all_rule_counters'):
      _assert_counter_equal(self, counters.all, expected.all)

    with self.subTest('explicit_rule_counters'):
      # This should be empty, but checking anyway for completeness.
      _assert_counter_equal(self, counters.explicit, expected.explicit)

    with self.subTest('hidden_true_rule_counters'):
      _assert_counter_equal(self, counters.hidden_true, expected.hidden_true)

    with self.subTest('hidden_unknown_rule_counters'):
      _assert_counter_equal(self, counters.hidden_unknown,
                            expected.hidden_unknown)

    with self.subTest('unreliable_rule_counters'):
      _assert_counter_equal(self, counters.unreliable, expected.unreliable)

    with self.subTest('distractor_rule_counters'):
      _assert_counter_equal(self, counters.distractor, expected.distractor)

    # Now we check the whole RuleBreakdownCounters object just in case we missed
    # something in the finer-grained checks above.
    with self.subTest('full_nested_counters'):
      _assert_counter_equal(self, counters, expected)

  def test_update_with_context_should_count_illustrated_rule_substitutions(
      self):
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            rules=[
                'rule1', 'rule2', 'rule3', 'rule4', 'rule5', 'rule6', 'rule7'
            ],
            hidden_rules=['rule1', 'rule2', 'rule3', 'rule4', 'rule5', 'rule6'],
            unreliable_rules=['rule7'],
            variable_substitutions_by_rule={
                # Two rules with num_context_variable_substitutions = 2.
                'rule1': {
                    'x1': {'a', 'b', 'c'},
                    'x2': {'d', 'e'},
                },
                'rule2': {
                    'x2': {'d', 'e'},
                },
                # One rule with num_context_variable_substitutions = 0.
                # (An empty *set* of substitutions, matching the set values
                # used for every other variable; `{}` would be a dict.)
                'rule3': {
                    'x1': set()
                },
                # Three rules with num_context_variable_substitutions = MAX.
                'rule4': {},
                'rule5': {
                    'x1': {'x1'}
                },
                'rule6': {
                    'x1': {'x2', 'f'}
                },
                # This rule has num_context_variable_substitutions = 3, but
                # should be counted separately because it is unreliable.
                'rule7': {
                    'x1': {'g', 'h', 'i'}
                },
            },
            outer_substitutions_by_rule={
                # Two rules with num_context_outer_substitutions = MAX.
                'rule1': {'__'},
                'rule2': {'__', '__ c'},
                # One rule with num_context_outer_substitutions = 2.
                'rule3': {'__ a', 'b __'},
                # This rule has num_context_outer_substitutions = 3, but should
                # be counted separately because it is unreliable.
                'rule7': {'a __', 'b __', 'c __'},
            },
            reliable_variable_substitutions_by_rule={
                # One rule with num_context_reliable_variable_substitutions = 1.
                'rule1': {
                    'x1': {'a'}
                },
            },
            reliable_outer_substitutions_by_rule={
                # One rule with num_context_reliable_outer_substitutions = 2.
                'rule1': {'__ f', 'g __'}
            }))
    mutable_context.metadata.rule_reply_by_hidden_rule = {
        rule: cl.RuleReply.TRUE
        for rule in mutable_context.metadata.hidden_rules
    }
    context = cl.FrozenExampleSet.from_example_set(mutable_context)

    counters = outputs.RuleBreakdownCounters()
    counters.update_with_context(context)

    with self.subTest('by_num_context_variable_substitutions_hidden_true'):
      self.assertDictEqual(
          {
              0: 1,
              2: 2,
              production_trees.INFINITE_SUBSTITUTIONS: 3,
          }, counters.hidden_true.by_num_context_variable_substitutions)

    with self.subTest('by_num_context_variable_substitutions_unreliable'):
      self.assertDictEqual({
          3: 1,
      }, counters.unreliable.by_num_context_variable_substitutions)

    with self.subTest('by_num_context_variable_substitutions_all'):
      # Should be sum of the hidden_true and unreliable counters.
      self.assertDictEqual(
          {
              0: 1,
              2: 2,
              3: 1,
              production_trees.INFINITE_SUBSTITUTIONS: 3,
          }, counters.all.by_num_context_variable_substitutions)

    with self.subTest('by_num_context_outer_substitutions_hidden_true'):
      self.assertDictEqual({
          2: 1,
          production_trees.INFINITE_SUBSTITUTIONS: 2,
      }, counters.hidden_true.by_num_context_outer_substitutions)

    with self.subTest('by_num_context_outer_substitutions_unreliable'):
      self.assertDictEqual({
          3: 1,
      }, counters.unreliable.by_num_context_outer_substitutions)

    with self.subTest('by_num_context_outer_substitutions_all'):
      # Should be sum of the hidden_true and unreliable counters.
      self.assertDictEqual(
          {
              2: 1,
              3: 1,
              production_trees.INFINITE_SUBSTITUTIONS: 2,
          }, counters.all.by_num_context_outer_substitutions)

    with self.subTest('by_num_context_reliable_variable_substitutions_all'):
      # Just verifying that the 'reliable' variant gets populated too.
      # (Details of the calculation should be the same as above.)
      self.assertDictEqual({
          1: 1,
      }, counters.all.by_num_context_reliable_variable_substitutions)

    with self.subTest('by_num_context_reliable_outer_substitutions_all'):
      # Just verifying that the 'reliable' variant gets populated too.
      self.assertDictEqual({
          2: 1,
      }, counters.all.by_num_context_reliable_outer_substitutions)


class GenerationTimingTest(absltest.TestCase):
  """Tests for `outputs.GenerationTiming`."""

  def test_total_should_be_sum_of_the_individual_phases(self):
    timing = outputs.GenerationTiming(
        generate_dataset=1.0,
        split_dataset=2.0,
        summarize_dataset=3.0,
    )
    expected_total = 1.0 + 2.0 + 3.0
    # Expected value first, consistent with the other assertions in this file.
    self.assertEqual(expected_total, timing.total)


class GenerationStatsTest(parameterized.TestCase):
  """Tests for JSON round-tripping and addition of outputs.GenerationStats."""

  @parameterized.named_parameters(
      ('empty', outputs.GenerationStats()),
      ('non_empty',
       outputs.GenerationStats(
           counters=outputs.GenerationCounters(
               example_attempts=outputs.ExampleAttemptCounters(
                   duplicate=1, valid=2),
               examples=outputs.ExampleCounters(
                   by_request_type=collections.defaultdict(
                       int, {cl.RequestType.NON_RULE: 3}),
                   by_example_type=collections.defaultdict(
                       int, {cl.ExampleType.NONRULE_KNOWN_D: 5}),
                   by_qualifier=collections.defaultdict(int,
                                                        {cl.Qualifier.D: 4})),
               contexts=outputs.ContextCounters(),
               errors=outputs.GenerationErrorCounters()),
           timing=outputs.GenerationTiming(generate_dataset=3.14))))
  def test_json_serialization_roundtrip(self, stats):
    """to_json followed by from_json should yield an equal GenerationStats."""
    try:
      stats_as_json = stats.to_json()
    except TypeError:
      # Convert a serialization TypeError into a test failure that carries
      # the full traceback in its message.
      self.fail('Exception raised when converting GenerationStats to JSON: '
                f'{traceback.format_exc()}')
    logging.info('Original GenerationStats:\n%s', stats)
    logging.info('GenerationStats as JSON:\n%s', stats_as_json)

    restored_stats = outputs.GenerationStats.from_json(stats_as_json)
    self.assertEqual(stats, restored_stats)

  def _build_stats_for_addition(self):
    """Returns a freshly constructed GenerationStats with all counters set.

    Each call builds a new object. This lets test_is_addable add two distinct
    but equal operands, rather than adding one object to itself, and then
    compare each operand against a pristine copy afterwards.
    """
    return outputs.GenerationStats(
        counters=outputs.GenerationCounters(
            example_attempts=outputs.ExampleAttemptCounters(
                already_in_context=1,
                ambiguous=2,
                duplicate=3,
                max_derivation_level_reached=4,
                missing_target_rule=5,
                unparseable=6,
                unable_to_create_negative_example=7,
                valid=8,
            ),
            examples=outputs.ExampleCounters(
                by_request_type=collections.Counter({
                    cl.RequestType.NON_RULE: 10,
                    cl.RequestType.RULE: 11
                }),
                by_example_type=collections.Counter({
                    cl.ExampleType.NONRULE_KNOWN_D: 60,
                    cl.ExampleType.RULE_UNKNOWN_D: 61
                }),
                by_qualifier=collections.Counter({
                    cl.Qualifier.D: 40,
                    cl.Qualifier.M: 41
                }),
                by_triviality=collections.Counter({
                    cl.Triviality.NEGATION: 50,
                    cl.Triviality.NON_TRIVIAL: 51
                }),
                input_length_stats_standard=outputs.DistributionSummaryStats(
                    max=70.0, sum=71.0),
                output_length_stats_standard=outputs.DistributionSummaryStats(
                    max=72.0, sum=73.0),
                input_length_stats_compact=outputs.DistributionSummaryStats(
                    max=74.0, sum=75.0),
                output_length_stats_compact=outputs.DistributionSummaryStats(
                    max=76.0, sum=77.0)),
            context_attempts=outputs.ContextAttemptCounters(
                poor_illustration_quality=101,
                unreliable_rule_illustrated_only_one_way=102,
                exceeded_max_input_length=103,
            ),
            contexts=outputs.ContextCounters(
                total=20,
                by_num_omitted_rules=collections.Counter({
                    0: 21,
                    1: 22
                }),
                by_num_explicit_rules=collections.Counter({
                    2: 31,
                    3: 32
                }),
                by_num_hidden_rules=collections.Counter({
                    4: 41,
                    5: 42
                }),
                by_num_hidden_true_rules=collections.Counter({
                    4: 43,
                    5: 44
                }),
                by_num_hidden_unknown_rules=collections.Counter({
                    4: 45,
                    5: 46
                }),
                by_num_rules=collections.Counter({
                    6: 51,
                    7: 52
                }),
                by_num_examples=collections.Counter({
                    8: 61,
                    9: 62
                }),
                by_num_unknown_rule_top_level_examples=collections.Counter({
                    10: 71,
                    11: 72
                }),
                by_num_positive_rule_top_level_examples=collections.Counter({
                    12: 81,
                    13: 82
                }),
                by_num_negative_rule_top_level_examples=collections.Counter({
                    14: 91,
                    15: 92
                }),
                by_num_unknown_nonrule_top_level_examples=collections.Counter({
                    16: 101,
                    17: 102
                })),
            rules=outputs.RuleBreakdownCounters(
                all=outputs.RuleCounters(
                    total=1,
                    by_num_context_examples=collections.Counter({
                        4: 14,
                        5: 15,
                    }),
                    by_num_reliable_context_examples=collections.Counter({
                        6: 16,
                        7: 17,
                    }),
                    by_min_reliable_derivation_level=collections.Counter({
                        8: 18,
                        9: 19,
                    }),
                    by_num_context_variable_substitutions=collections.Counter({
                        10: 20,
                        11: 21,
                    }),
                    by_num_context_outer_substitutions=collections.Counter({
                        12: 22,
                        13: 23,
                    }),
                    by_num_context_reliable_variable_substitutions=collections
                    .Counter({
                        10: 24,
                        11: 25,
                    }),
                    by_num_context_reliable_outer_substitutions=collections
                    .Counter({
                        12: 26,
                        13: 27,
                    })),
                explicit=outputs.RuleCounters(total=30),
                hidden_true=outputs.RuleCounters(total=31),
                hidden_unknown=outputs.RuleCounters(total=32),
                unreliable=outputs.RuleCounters(total=33),
                distractor=outputs.RuleCounters(total=34),
                omitted=outputs.RuleCounters(total=35)),
            errors=outputs.GenerationErrorCounters(
                failed_to_illustrate_target_rule=40,
                failed_to_generate_example_of_desired_request_type=41,
                failed_to_generate_derived_production=42,
                failed_to_generate_context=43,
                failed_to_generate_grammar=44)),
        timing=outputs.GenerationTiming(
            generate_dataset=1.0, split_dataset=2.0))

  def test_is_addable(self):
    """Adding GenerationStats combines fields and leaves operands intact."""
    # Two distinct (but equal) operands, so that an __add__ which aliases or
    # mutates one of its inputs cannot hide behind `x + x`.
    stats1 = self._build_stats_for_addition()
    stats2 = self._build_stats_for_addition()
    expected_sum = outputs.GenerationStats(
        counters=outputs.GenerationCounters(
            example_attempts=outputs.ExampleAttemptCounters(
                already_in_context=2,
                ambiguous=4,
                duplicate=6,
                max_derivation_level_reached=8,
                missing_target_rule=10,
                unparseable=12,
                unable_to_create_negative_example=14,
                valid=16,
            ),
            examples=outputs.ExampleCounters(
                by_request_type={
                    cl.RequestType.NON_RULE: 20,
                    cl.RequestType.RULE: 22
                },
                by_example_type={
                    cl.ExampleType.NONRULE_KNOWN_D: 120,
                    cl.ExampleType.RULE_UNKNOWN_D: 122
                },
                by_qualifier={
                    cl.Qualifier.D: 80,
                    cl.Qualifier.M: 82
                },
                by_triviality={
                    cl.Triviality.NEGATION: 100,
                    cl.Triviality.NON_TRIVIAL: 102
                },
                input_length_stats_standard=outputs.DistributionSummaryStats(
                    # Note that when adding DistributionSummaryStats, the "max"
                    # field of the result is the max of the original max values,
                    # not the sum of them.
                    max=70.0,
                    sum=142.0),
                output_length_stats_standard=outputs.DistributionSummaryStats(
                    max=72.0, sum=146.0),
                input_length_stats_compact=outputs.DistributionSummaryStats(
                    max=74.0, sum=150.0),
                output_length_stats_compact=outputs.DistributionSummaryStats(
                    max=76.0, sum=154.0)),
            context_attempts=outputs.ContextAttemptCounters(
                poor_illustration_quality=202,
                unreliable_rule_illustrated_only_one_way=204,
                exceeded_max_input_length=206,
            ),
            contexts=outputs.ContextCounters(
                total=40,
                by_num_omitted_rules={
                    0: 42,
                    1: 44
                },
                by_num_explicit_rules={
                    2: 62,
                    3: 64
                },
                by_num_hidden_rules={
                    4: 82,
                    5: 84
                },
                by_num_hidden_true_rules={
                    4: 86,
                    5: 88
                },
                by_num_hidden_unknown_rules={
                    4: 90,
                    5: 92
                },
                by_num_rules={
                    6: 102,
                    7: 104
                },
                by_num_examples={
                    8: 122,
                    9: 124
                },
                by_num_unknown_rule_top_level_examples={
                    10: 142,
                    11: 144
                },
                by_num_positive_rule_top_level_examples={
                    12: 162,
                    13: 164
                },
                by_num_negative_rule_top_level_examples={
                    14: 182,
                    15: 184
                },
                by_num_unknown_nonrule_top_level_examples={
                    16: 202,
                    17: 204
                }),
            rules=outputs.RuleBreakdownCounters(
                all=outputs.RuleCounters(
                    total=2,
                    by_num_context_examples=collections.Counter({
                        4: 28,
                        5: 30,
                    }),
                    by_num_reliable_context_examples=collections.Counter({
                        6: 32,
                        7: 34,
                    }),
                    by_min_reliable_derivation_level=collections.Counter({
                        8: 36,
                        9: 38,
                    }),
                    by_num_context_variable_substitutions=collections.Counter({
                        10: 40,
                        11: 42,
                    }),
                    by_num_context_outer_substitutions=collections.Counter({
                        12: 44,
                        13: 46,
                    }),
                    by_num_context_reliable_variable_substitutions=collections
                    .Counter({
                        10: 48,
                        11: 50,
                    }),
                    by_num_context_reliable_outer_substitutions=collections
                    .Counter({
                        12: 52,
                        13: 54,
                    })),
                explicit=outputs.RuleCounters(total=60),
                hidden_true=outputs.RuleCounters(total=62),
                hidden_unknown=outputs.RuleCounters(total=64),
                unreliable=outputs.RuleCounters(total=66),
                distractor=outputs.RuleCounters(total=68),
                omitted=outputs.RuleCounters(total=70)),
            errors=outputs.GenerationErrorCounters(
                failed_to_illustrate_target_rule=80,
                failed_to_generate_example_of_desired_request_type=82,
                failed_to_generate_derived_production=84,
                failed_to_generate_context=86,
                failed_to_generate_grammar=88)),
        timing=outputs.GenerationTiming(
            generate_dataset=2.0, split_dataset=4.0))

    actual_sum = stats1 + stats2

    with self.subTest('sum_should_combine_all_fields'):
      self.assertEqual(actual_sum, expected_sum)

    # Addition should produce a new value rather than accumulating into either
    # operand; compare each operand against a freshly built, untouched copy.
    with self.subTest('left_operand_should_be_unchanged'):
      self.assertEqual(stats1, self._build_stats_for_addition())

    with self.subTest('right_operand_should_be_unchanged'):
      self.assertEqual(stats2, self._build_stats_for_addition())


if __name__ == '__main__':
  # When executed as a script, hand control to absl's test runner, which runs
  # the test cases defined in this module.
  absltest.main()
