# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the dataset_generation module concerning counters."""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


class DatasetGenerationCountersTest(parameterized.TestCase):
  """Checks the counters that `generate_dataset` fills in while generating."""

  def setUp(self):
    super().setUp()
    # Fixed seed so that sampled grammars and examples are reproducible.
    self.rng = np.random.RandomState(42)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_dataset_should_populate_examples_counters_by_example_type(
      self, unused_mock):
    """Example counters should agree with the examples actually generated.

    Both the total count and the per-example-type breakdown are compared
    against the flattened dataset.

    Args:
      unused_mock: The patched `InferenceEngine` class (not inspected).
    """
    options = inputs.GenerationOptions(
        grammar=inputs.GrammarOptions(
            num_primitives=3,
            num_precedence_levels=3,
            max_repetitions_per_token_in_output_sequence=1),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            omitted_fraction=0.1,
            explicit_fraction=0.6,
            non_rule_fraction=0.2,
            negative_example_fraction=0.5))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    flat_dataset = dataset.to_example_set()

    with self.subTest('should_contain_correct_total_number_of_examples'):
      self.assertLen(flat_dataset, counters.examples.get_total())

    # `__members__` is the public Enum mapping from member name to member;
    # iterating it yields the member names, which are used below both for the
    # comparison with `get_example_type()` and as `by_example_type` keys.
    for example_type in cl.ExampleType.__members__:
      with self.subTest(f'should_contain_correct_number_of_{example_type}'):
        examples_of_example_type = [
            example for example in flat_dataset
            if example.get_example_type() == example_type
        ]
        self.assertLen(examples_of_example_type,
                       counters.examples.by_example_type[example_type])

  def test_generate_dataset_should_populate_counters_already_in_context(self):
    """Attempts that produce an example already in the context are counted."""
    num_contexts = 2
    num_requests_per_context = 3
    max_attempts_per_example = 20
    # Here we generate very small grammars from which only one unique
    # example can be generated from each.  This guarantees that all attempts
    # to generate top-level examples would be already in the context.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=1,
            num_precedence_levels=1,
            num_functions_per_level=1,
            has_pass_through_rules=False,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=num_contexts,
            num_requests_per_context=num_requests_per_context,
            max_attempts_per_example=max_attempts_per_example,
            explicit_fraction=0.0,
            non_rule_fraction=1.0))
    counters = outputs.GenerationCounters()

    unused_dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    # We sample productions from the inference engine when generating examples
    # with known reply, and in this case every distinct example is only yielded
    # once, so we expect exactly one already-in-context attempt per context.
    expected_num_already_in_context = num_contexts
    self.assertEqual(counters.example_attempts.already_in_context,
                     expected_num_already_in_context)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine(
          never_contains=True, all_consistent=True))
  def test_generate_dataset_should_set_unknown_reply_and_original_reply(
      self, unused_mock):
    """Every example gets an UNKNOWN reply and is counted as UNKNOWN.

    Args:
      unused_mock: The patched `InferenceEngine` class (not inspected).
    """
    # We use a fake inference engine that does not recognize any production,
    # and considers everything consistent so that all the top-level examples
    # should have UNKNOWN reply, with original reply set, and counted in the
    # knownness counter.
    # NOTE(review): despite the test name, the original reply itself is not
    # asserted below; only the UNKNOWN reply and the knownness counter are.
    options = inputs.GenerationOptions(
        grammar=test_utils.create_grammar_options_with_fixed_number_of_rules(
            num_primitives=5,
            num_precedence_levels=3,
            num_functions_per_level=2,
            has_pass_through_rules=True,
            has_concat_rule=False),
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            num_requests_per_context=10,
            omitted_fraction=0.1,
            unreliable_fraction=0.0,
            explicit_fraction=0.2,
            non_rule_fraction=0.1))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    flat_dataset = dataset.to_example_set()
    examples_with_unknown_reply = [
        example for example in flat_dataset
        if example.reply == cl.RuleReply.UNKNOWN
    ]

    with self.subTest('should_set_unknown_reply_correctly'):
      # Equal multisets means every example in the dataset has UNKNOWN reply.
      self.assertCountEqual(flat_dataset, examples_with_unknown_reply)

    with self.subTest('should_update_by_knownness_counter'):
      self.assertLen(flat_dataset, counters.examples.by_knownness['UNKNOWN'])


if __name__ == '__main__':
  # Run the test cases defined in this module with absl's test runner.
  absltest.main()
