# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Additional tests for the dataset_generation module.

This file contains tests that are concerned with the behavior under specific
sampling options (e.g. explicit_fraction or omitted_fraction).
"""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


class DatasetGenerationAdditionalTest(parameterized.TestCase):
  """Tests of generate_dataset under specific rule-sampling fractions.

  Every test generates a dataset with a single context and compares the rules
  and examples of that context, as well as the generation counters, against
  the numbers implied by the sampling options.

  In each test `dataset_generation.inference.InferenceEngine` is patched so
  that calling it returns the object built by
  `test_utils.make_fake_inference_engine()`. Note that this object is created
  once, when the decorators are evaluated at class-definition time.
  """

  def setUp(self):
    super().setUp()
    # Fixed seed so that every test starts from the same random state.
    self.rng = np.random.RandomState(42)

  def _whole_rule_count(self, num_rules, fraction, rule_kind):
    """Returns the number of rules selected by `fraction` as an int.

    Fails the test if `num_rules * fraction` is not a whole number, since the
    expectations computed in these tests are only meaningful in that case.

    Args:
      num_rules: Total number of rules in the context.
      fraction: Value of the sampling option `<rule_kind>_fraction`.
      rule_kind: Kind of rule that `fraction` applies to (e.g. 'explicit').
        Only used to build the failure message.

    Returns:
      `num_rules * fraction` converted to an int.
    """
    count = num_rules * fraction
    # Use assertTrue rather than a bare `assert`, so that the check is not
    # stripped when Python runs with -O and is reported like any other test
    # failure. float() keeps is_integer() available even if `fraction` is
    # given as an int.
    self.assertTrue(
        float(count).is_integer(),
        f'{rule_kind}_fraction {fraction} and GrammarOptions should be set in '
        f'this test to values that yield a fixed number of {rule_kind} rules '
        f'(currently {count})')
    return int(count)

  @parameterized.named_parameters(('all_explicit', 1.0), ('half_hidden', 0.5),
                                  ('all_hidden', 0.0))
  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_explicit_fraction(self, explicit_fraction,
                                              unused_mock):
    """Checks rule and example counts for a given explicit_fraction.

    Args:
      explicit_fraction: Fraction of the context's rules expected to be
        explicit; all remaining rules are expected to be hidden.
      unused_mock: The patched InferenceEngine, injected by mock.patch.object.
    """
    # Here we specify options that ensure the grammar has an even number of
    # rules, so that each explicit_fraction used above (0.0, 0.5 or 1.0)
    # yields a whole, fully predictable number of explicit and hidden rules.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=5,
            num_precedence_levels=3,
            num_functions_per_level=2,
            has_pass_through_rules=True,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            explicit_fraction=explicit_fraction,
            non_rule_fraction=1.0))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    # Only one context is requested, so all checks look at the first group.
    context = dataset.example_groups[0].context
    num_rules = len(context.metadata.rules)

    num_explicit = self._whole_rule_count(
        num_rules, options.sampling.explicit_fraction, 'explicit')
    num_hidden = num_rules - num_explicit
    # Contexts should contain one example to assert each explicit rule plus
    # the configured number of examples to illustrate each hidden rule.
    num_context_examples = (
        num_explicit +
        num_hidden * options.sampling.num_examples_per_hidden_rule)
    # Sampling is used for generating top-level non-rule requests as well as
    # requests for illustrating hidden rules within the context.
    num_sampled_examples = (
        test_utils.get_expected_num_top_level_examples(options.sampling) +
        options.sampling.num_contexts * num_hidden *
        options.sampling.num_examples_per_hidden_rule)

    # Built once and attached to every assertion below as the failure message.
    summary = test_utils.get_dataset_summary(counters, dataset)

    with self.subTest('context_should_have_expected_number_of_explicit_rules'):
      self.assertLen(context.metadata.explicit_rules, num_explicit, summary)

    with self.subTest('context_should_have_expected_number_of_hidden_rules'):
      self.assertLen(context.metadata.hidden_rules, num_hidden, summary)

    with self.subTest('context_should_have_expected_number_of_nested_examples'):
      self.assertLen(context, num_context_examples, summary)

    with self.subTest(
        'valid_attempts_should_match_expected_number_of_sampled_examples'):
      self.assertEqual(num_sampled_examples, counters.example_attempts.valid,
                       summary)

  @parameterized.named_parameters(
      ('all_omitted', 1.0, 0.0), ('some_omitted', 0.25, 0.25),
      ('half_omitted', 0.5, 0.25), ('none_omitted', 0.0, 0.5))
  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_omitted_fraction(self, omitted_fraction,
                                             explicit_fraction, unused_mock):
    """Checks rule and example counts for a given omitted_fraction.

    Args:
      omitted_fraction: Fraction of the context's rules expected to be omitted.
      explicit_fraction: Fraction of the context's rules expected to be
        explicit; rules that are neither omitted nor explicit are expected to
        be hidden.
      unused_mock: The patched InferenceEngine, injected by mock.patch.object.
    """
    # Here we specify options that ensure the grammar has a number of rules
    # divisible by 4. Every omitted_fraction and explicit_fraction used above
    # is a multiple of 0.25, so the number of omitted, explicit and hidden
    # rules is always whole and fully predictable.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=4,
            num_precedence_levels=2,
            num_functions_per_level=3,
            has_pass_through_rules=True,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            omitted_fraction=omitted_fraction,
            explicit_fraction=explicit_fraction,
            non_rule_fraction=1.0))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    # Only one context is requested, so all checks look at the first group.
    context = dataset.example_groups[0].context
    num_rules = len(context.metadata.rules)

    num_omitted = self._whole_rule_count(
        num_rules, options.sampling.omitted_fraction, 'omitted')
    num_explicit = self._whole_rule_count(
        num_rules, options.sampling.explicit_fraction, 'explicit')
    num_hidden = num_rules - num_omitted - num_explicit
    # Contexts should contain one example to assert each explicit rule plus
    # the configured number of examples to illustrate each hidden rule.
    num_context_examples = (
        num_explicit +
        num_hidden * options.sampling.num_examples_per_hidden_rule)
    # Sampling is used for generating top-level non-rule requests as well as
    # requests for illustrating hidden rules within the context.
    num_sampled_examples = (
        test_utils.get_expected_num_top_level_examples(options.sampling) +
        options.sampling.num_contexts * num_hidden *
        options.sampling.num_examples_per_hidden_rule)

    # Built once and attached to every assertion below as the failure message.
    summary = test_utils.get_dataset_summary(counters, dataset)

    with self.subTest('context_should_have_expected_number_of_omitted_rules'):
      self.assertLen(context.metadata.omitted_rules, num_omitted, summary)

    with self.subTest('context_should_have_expected_number_of_explicit_rules'):
      self.assertLen(context.metadata.explicit_rules, num_explicit, summary)

    with self.subTest('context_should_have_expected_number_of_hidden_rules'):
      self.assertLen(context.metadata.hidden_rules, num_hidden, summary)

    with self.subTest('context_should_have_expected_number_of_nested_examples'):
      self.assertLen(context, num_context_examples, summary)

    # Contexts may be discarded, which means that the valid example attempts
    # may exceed the expected examples.
    with self.subTest(
        'valid_attempts_should_match_or_exceed_expected_number_of_sampled_examples'
    ):
      self.assertGreaterEqual(counters.example_attempts.valid,
                              num_sampled_examples, summary)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
