# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic behavior tests for the dataset_generation module.

This file contains tests that verify the dataset_generation module's basic
behavior.  Tests for more specific edge cases (controlled by sampling options)
and tests for counters are in dataset_generation_additional_test.py.
"""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


class DatasetGenerationBasicBehaviorTest(parameterized.TestCase):
  """End-to-end tests of dataset_generation.generate_dataset.

  Each test builds a GenerationOptions, runs generate_dataset with a seeded
  random state and a fresh GenerationCounters, and then checks the resulting
  dataset and counters.
  """

  def setUp(self):
    """Creates a fixed-seed random state shared by the test's sampling."""
    super().setUp()
    self.rng = np.random.RandomState(42)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_basic_behavior(self, unused_mock):
    """Checks example/context counts and sampling efficiency for 2 contexts.

    InferenceEngine is patched so that constructing one returns the fake engine
    from test_utils.make_fake_inference_engine().
    """
    options = inputs.GenerationOptions(
        sampling=test_utils.create_sampling_options(
            num_contexts=2, non_rule_fraction=1.0))
    counters = outputs.GenerationCounters()

    # Note that we are performing as many different assertions as possible in
    # a single test case (organized using subTests), as each call to
    # generate_dataset is rather expensive.
    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    # Note that we are purposefully avoiding use of self.assertLen here, as in
    # case of failure, assertLen would log the entire dataset, which is unwieldy
    # (>1MB) and generally unhelpful.
    actual_num_top_level_examples = len(dataset.to_example_set())

    # The summary is used as the failure message of every assertion below, so
    # we build it once instead of once per assertion.
    summary = test_utils.get_dataset_summary(counters, dataset)

    with self.subTest('dataset_should_contain_requested_number_of_examples'):
      self.assertEqual(
          test_utils.get_expected_num_top_level_examples(options.sampling),
          actual_num_top_level_examples, summary)

    with self.subTest('dataset_should_contain_requested_number_of_contexts'):
      self.assertEqual(options.sampling.num_contexts, counters.contexts.total,
                       summary)

    with self.subTest('dataset_should_contain_one_example_group_per_context'):
      # Compare the length directly rather than via assertLen, which would log
      # the whole list of example groups on failure.
      self.assertEqual(options.sampling.num_contexts,
                       len(dataset.example_groups), summary)

    with self.subTest('number_of_examples_in_dataset_should_match_counters'):
      self.assertEqual(actual_num_top_level_examples,
                       counters.examples.get_total(), summary)

    with self.subTest('all_examples_should_be_non_rule_requests'):
      self.assertEqual(
          counters.examples.get_total(),
          counters.examples.by_request_type[cl.RequestType.NON_RULE], summary)

    # Note that since the dataset generation process is based on random
    # sampling, not all of the desired behavior can be described crisply in a
    # "this must be true; this must not be true" fashion. To minimize the chance
    # of false alarms, we focus most of our tests on just the deterministic
    # subset of behaviors of interest. The following two assertions, however,
    # seek to verify that the use of random sampling itself is being performed
    # in a reasonably efficient way. As we are concerned primarily with avoiding
    # blatantly pathological behaviors, we set the thresholds for the below
    # tests conservatively, so that the chance of test failure with a reasonable
    # implementation should be vanishingly low. In reality, under the default
    # settings, we would expect failed_to_illustrate_target_rule to normally be
    # zero, while the valid fraction would normally be well over 30%. If either
    # of the below tests fail, it is almost certainly a sign of a bug or design
    # flaw in the dataset generation logic.

    with self.subTest('should_rarely_if_ever_fail_to_illustrate_target_rule'):
      self.assertLessEqual(counters.errors.failed_to_illustrate_target_rule, 10,
                           summary)

    with self.subTest('a_reasonable_fraction_of_attempts_should_succeed'):
      self.assertGreaterEqual(counters.example_attempts.get_valid_fraction(),
                              0.05, summary)

  def test_generate_dataset_skips_duplicates_at_top_level(self):
    """Checks counters when each grammar admits only one top-level example."""
    # Here we request a large number of requests per context while generating
    # grammars that are so small that only one unique example can be generated
    # from each. This guarantees that a large and predictable number of
    # duplicate top-level examples will be generated.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=1,
            num_precedence_levels=1,
            num_functions_per_level=1,
            has_pass_through_rules=False,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=2,
            num_requests_per_context=50,
            max_attempts_per_example=5,
            # We make all the rules in the context explicit, so that sampling is
            # done only at the top level. That way we know exactly how many
            # successful and unsuccessful attempts to expect in the counters.
            explicit_fraction=1.0,
            non_rule_fraction=1.0))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    actual_num_top_level_examples = len(dataset.to_example_set())
    summary = test_utils.get_dataset_summary(counters, dataset)

    # Only one unique example can be generated for each of these grammars.
    expected_successful_top_level_examples = options.sampling.num_contexts

    # We sample productions from the inference engine when generating examples
    # with known reply, and in this case every distinct example is only yielded
    # once.
    expected_total_attempts = options.sampling.num_contexts

    with self.subTest('should_generate_as_many_unique_examples_as_possible'):
      self.assertEqual(expected_successful_top_level_examples,
                       actual_num_top_level_examples, summary)

    with self.subTest(
        'should_give_up_after_performing_the_requested_number_of_attempts'):
      self.assertEqual(
          expected_total_attempts - expected_successful_top_level_examples,
          counters.example_attempts.duplicate, summary)

    with self.subTest(
        'should_count_an_error_when_num_requests_per_context_is_not_reached'):
      self.assertEqual(
          # We should hit this error once per grammar (i.e., once per context).
          options.sampling.num_contexts,
          counters.errors.failed_to_generate_example_of_desired_request_type,
          summary)

  def test_generate_dataset_skips_duplicates_within_context(self):
    """Checks counters when duplicates arise among illustrative examples."""
    # Here we again generate very small grammars from which only one unique
    # example can be generated from each. This time, however, we limit the
    # number of top-level examples generated while requesting a large number of
    # illustrative examples for each hidden rule so that the duplicates that
    # need to be skipped are all inside the context.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=1,
            num_precedence_levels=1,
            num_functions_per_level=1,
            has_pass_through_rules=False,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=2,
            num_requests_per_context=1,
            # We allow only non-rule examples in the context, so we can be sure
            # that there is only one possible example in the context.
            illustrative_example_non_rule_fraction=1.0,
            max_attempts_per_example=20,
            max_attempts_per_context=1,
            # We make all the rules in the context hidden, so that we know
            # exactly how many illustrative examples should be generated.
            explicit_fraction=0.0,
            non_rule_fraction=1.0,
            min_illustrative_examples=1))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    summary = test_utils.get_dataset_summary(counters, dataset)

    num_examples_generatable_per_grammar = 1
    expected_num_rules_per_grammar = (
        test_utils.get_expected_num_rules_per_grammar(options.grammar))

    expected_num_skipped_duplicates = (
        test_utils.get_expected_num_top_level_examples(options.sampling) *
        (expected_num_rules_per_grammar *
         options.sampling.num_examples_per_hidden_rule *
         options.sampling.max_attempts_per_example -
         num_examples_generatable_per_grammar))

    with self.subTest('should_generate_no_top_level_examples'):
      # Since the only possible example request is already in the context,
      # there should be no top level example.
      self.assertEmpty(dataset.to_example_set(), summary)

    with self.subTest('should_register_each_of_the_rules_in_the_metadata'):
      self.assertLen(dataset.example_groups[0].context.metadata.rules,
                     expected_num_rules_per_grammar, summary)

    with self.subTest(
        'should_generate_as_many_unique_illustrative_examples_as_possible'):
      self.assertLen(dataset.example_groups[0].context,
                     num_examples_generatable_per_grammar, summary)

    with self.subTest(
        'should_give_up_after_performing_the_requested_number_of_attempts'):
      self.assertEqual(expected_num_skipped_duplicates,
                       counters.example_attempts.duplicate, summary)

    with self.subTest('should_count_an_error_when_num_examples_per_hidden_rule_'
                      'is_not_reached'):
      self.assertEqual(
          # We should hit this error once per rule in each context.
          options.sampling.num_contexts * expected_num_rules_per_grammar,
          counters.errors.failed_to_illustrate_target_rule, summary)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_should_populate_qualifier(self, unused_mock):
    """Checks that the generated examples carry both qualifiers D and M.

    InferenceEngine is patched so that constructing one returns the fake engine
    from test_utils.make_fake_inference_engine().
    """
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=5,
            num_precedence_levels=3,
            num_functions_per_level=2,
            has_pass_through_rules=True,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            num_requests_per_context=10,
            defeasible_example_fraction=0.5))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    example_qualifiers = {
        example.qualifier for example in dataset.to_flat_examples()
    }

    with self.subTest('should_contain_defeasible_qualifiers'):
      self.assertEqual({cl.Qualifier.D, cl.Qualifier.M}, example_qualifiers,
                       test_utils.get_dataset_summary(counters, dataset))

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_should_gracefully_abort_if_max_grammar_attempts_is_reached(
      self, unused_mock):
    """Checks the error counters when context generation keeps failing.

    InferenceEngine is patched so that constructing one returns the fake engine
    from test_utils.make_fake_inference_engine().
    """
    max_attempts_per_grammar = 3
    # Here we force the grammar generator to generate grammars that are so small
    # that only one unique example can be generated from each, while requesting
    # contexts that contain hidden rules with multiple illustrative examples
    # for each. This guarantees that context generation will repeatedly fail.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=1,
            num_precedence_levels=1,
            num_functions_per_level=1,
            has_pass_through_rules=False,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            num_requests_per_context=5,
            illustrative_example_non_rule_fraction=1.0,
            num_examples_per_hidden_rule=5,
            explicit_fraction=0.0,
            max_attempts_per_example=5,
            max_attempts_per_context=2,
            max_attempts_per_grammar=max_attempts_per_grammar,
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        ))
    counters = outputs.GenerationCounters()
    dataset = dataset_generation.generate_dataset(
        options=options,
        counters=counters,
        rng=self.rng,
    )
    summary = test_utils.get_dataset_summary(counters, dataset)

    with self.subTest('should_generate_no_top_level_examples'):
      self.assertEmpty(dataset.to_example_set(), summary)

    with self.subTest('should_count_failures_to_generate_contex'):
      self.assertEqual(max_attempts_per_grammar,
                       counters.errors.failed_to_generate_context, summary)

    with self.subTest('should_count_failure_to_generate_grammar'):
      self.assertEqual(1, counters.errors.failed_to_generate_grammar, summary)

# Run the tests through absltest's runner when executed as a script.
if __name__ == '__main__':
  absltest.main()
