# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic behavior tests for the dataset_generation module.

This file contains tests that verify the dataset_generation module's basic
behavior.  Tests for more specific edge cases (controlled by sampling options)
and tests for counters are in dataset_generation_additional_test.py.
"""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


class DatasetGenerationBasicBehaviorTest(parameterized.TestCase):
  """Checks the mix of request types counted by `generate_dataset`."""

  def setUp(self):
    super().setUp()
    # Fixed seed keeps the random draws made through this RNG reproducible.
    self.rng = np.random.RandomState(42)

  @parameterized.named_parameters(
      ('all_non_rule_requests', 1.0, 0.5, 12, 12, 0, 0),
      ('mixture_of_request_types', 0.5, 0.5, 12, 6, 3, 3),
      ('all_positive_rule_requests', 0.0, 0.0, 12, 0, 12, 0),
      ('all_negative_rule_requests', 0.0, 1.0, 12, 0, 0, 12))
  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_with_rule_requests(
      self, non_rule_fraction, negative_example_fraction,
      num_top_level_examples, expected_num_non_rule_requests,
      expected_num_positive_rule_requests,
      expected_num_negative_rule_requests, unused_mock):
    """Checks the per-request-type example counts of a generated dataset.

    Args:
      non_rule_fraction: Value passed as the `non_rule_fraction` sampling
        option.
      negative_example_fraction: Value passed as the
        `negative_example_fraction` sampling option.
      num_top_level_examples: Number of top-level examples that the sampling
        options are required to yield; also used to derive the number of
        requests per context.
      expected_num_non_rule_requests: Exact number of NON_RULE requests that
        the counters must report.
      expected_num_positive_rule_requests: Number of RULE_KNOWN_TRUE_*
        examples expected, checked with a tolerance (see below).
      expected_num_negative_rule_requests: Number of RULE_KNOWN_FALSE_*
        examples expected, checked with a tolerance (see below).
      unused_mock: The `InferenceEngine` mock injected by `mock.patch.object`;
        not used directly.
    """
    # Number of contexts should be kept small so that the tests run fast, but
    # requests per context also shouldn't be too large, as otherwise we may
    # need to generate too many rule requests from the same context and not have
    # enough hidden rules to do so. (This latter point may be less of a problem
    # now that we support derived rules.)
    num_contexts = 2
    options = inputs.GenerationOptions(
        sampling=test_utils.create_sampling_options(
            num_contexts=num_contexts,
            # Floor division keeps this an int without a float round-trip.
            num_requests_per_context=num_top_level_examples // num_contexts,
            non_rule_fraction=non_rule_fraction,
            negative_example_fraction=negative_example_fraction))
    counters = outputs.GenerationCounters()

    # Precondition on the test parameters themselves.  This uses a unittest
    # assertion rather than a bare `assert` statement so that the check is not
    # stripped when Python runs with optimizations enabled (-O).
    self.assertEqual(
        num_top_level_examples,
        test_utils.get_expected_num_top_level_examples(options.sampling),
        'SamplingOptions should be set in this test to values that yield a '
        f'fixed number of examples of each request type: {options.sampling}')

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    by_example_type = counters.examples.by_example_type
    num_non_rule_requests = (
        counters.examples.by_request_type[cl.RequestType.NON_RULE])
    num_positive_rule_requests = (
        by_example_type[cl.ExampleType.RULE_KNOWN_TRUE_D] +
        by_example_type[cl.ExampleType.RULE_KNOWN_TRUE_M])
    num_negative_rule_requests = (
        by_example_type[cl.ExampleType.RULE_KNOWN_FALSE_D] +
        by_example_type[cl.ExampleType.RULE_KNOWN_FALSE_M])
    # We will typically get fewer than the expected number of positive/negative
    # rule examples since some of them will have UNKNOWN reply.
    num_unknown_rule_requests = by_example_type[cl.ExampleType.RULE_UNKNOWN_D]

    # Failure message shared by all assertions below; built once instead of
    # once per assertion.
    summary = test_utils.get_dataset_summary(counters, dataset)

    with self.subTest('should_contain_expected_number_of_non_rule_requests'):
      self.assertEqual(
          expected_num_non_rule_requests, num_non_rule_requests, summary)

    # Depending on the rounding, the actual counts are allowed to be off by no
    # more than one.  The lower bound is additionally relaxed by the number of
    # UNKNOWN rule examples observed.
    with self.subTest(
        'should_contain_expected_number_of_positive_rule_requests'):
      self.assertBetween(
          num_positive_rule_requests,
          expected_num_positive_rule_requests - num_unknown_rule_requests - 1,
          expected_num_positive_rule_requests + 1,
          summary)

    with self.subTest(
        'should_contain_expected_number_of_negative_rule_requests'):
      self.assertBetween(
          num_negative_rule_requests,
          expected_num_negative_rule_requests - num_unknown_rule_requests - 1,
          expected_num_negative_rule_requests + 1,
          summary)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
