# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the dataset_generation module concerning input/output lengths."""

import traceback
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


class DatasetGenerationLengthsTest(parameterized.TestCase):
  """Tests generate_dataset's handling of input/output lengths.

  Covers population of the per-example `*_length_standard` metadata and the
  matching length counters. Also covers rejection of examples that exceed the
  configured maximum input/output lengths.
  """

  def setUp(self):
    super().setUp()
    # Fixed seed so the generated dataset is the same on every run.
    self.rng = np.random.RandomState(42)

  def _create_options_for_length_limit_tests(self, **sampling_overrides):
    """Returns GenerationOptions shared by the length-limit tests.

    The base settings are one context with three requests, small attempt
    budgets, and non_rule_fraction=1.0.

    Args:
      **sampling_overrides: Extra keyword arguments forwarded to
        `test_utils.create_sampling_options` (typically the length limit under
        test). They take precedence over the base settings.

    Returns:
      An `inputs.GenerationOptions` wrapping the resulting sampling options.
    """
    sampling_kwargs = dict(
        num_contexts=1,
        num_requests_per_context=3,
        max_attempts_per_example=10,
        max_attempts_per_context=2,
        non_rule_fraction=1.0)
    sampling_kwargs.update(sampling_overrides)
    return inputs.GenerationOptions(
        sampling=test_utils.create_sampling_options(**sampling_kwargs))

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_should_populate_input_output_lengths(self, unused_mock):
    """Standard lengths should be recorded in metadata and in the counters."""
    options = inputs.GenerationOptions(
        sampling=test_utils.create_sampling_options(
            num_contexts=1,
            omitted_fraction=0.1,
            explicit_fraction=0.6,
            non_rule_fraction=0.2,
            negative_example_fraction=0.5))
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)
    flat_dataset = dataset.to_example_set()

    # Guard the `flat_dataset[0]` lookups below, so that an unexpectedly empty
    # dataset fails with a readable summary instead of an IndexError.
    self.assertNotEmpty(flat_dataset,
                        test_utils.get_dataset_summary(counters, dataset))

    # Note that only "standard" input and output lengths are populated by
    # default in unit tests, as the T5X tokenizer used in calculating the
    # "compact" input and output lengths would require a remote service
    # dependency.

    with self.subTest('example_metadata_should_contain_input_length'):
      self.assertGreater(flat_dataset[0].metadata.input_length_standard, 0)

    with self.subTest('example_metadata_should_contain_output_length'):
      self.assertGreater(flat_dataset[0].metadata.output_length_standard, 0)

    with self.subTest('all_examples_be_counted_in_input_length_stats'):
      self.assertLen(flat_dataset,
                     counters.examples.input_length_stats_standard.count)

    with self.subTest('all_examples_be_counted_in_output_length_stats'):
      self.assertLen(flat_dataset,
                     counters.examples.output_length_stats_standard.count)

    with self.subTest('input_length_should_include_both_context_and_request'):
      # The only way the input length could be less than about 20 would be if
      # the context were not getting counted.
      self.assertGreaterEqual(counters.examples.input_length_stats_standard.min,
                              20)

    with self.subTest('output_length_should_include_both_reply_and_qualifier'):
      # At minimum, there must be at least one token for the reply and one token
      # for the qualifier.
      self.assertGreaterEqual(
          counters.examples.output_length_stats_standard.min, 2)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_should_populate_counters_exceeded_max_input_length(
      self, unused_mock):
    """Exceeding max_input_length_standard should be counted and yield none."""
    # Here we configure a max_input_length that is so small that we know that
    # every generated example will exceed the maximum length.
    options = self._create_options_for_length_limit_tests(
        max_input_length_standard=1)
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    with self.subTest('should_generate_no_top_level_examples'):
      self.assertEmpty(dataset.to_example_set(),
                       test_utils.get_dataset_summary(counters, dataset))

    with self.subTest('should_count_failed_context_attempts'):
      self.assertGreater(counters.context_attempts.exceeded_max_input_length, 0)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_should_populate_counters_exceeded_max_output_length(
      self, unused_mock):
    """Exceeding max_output_length_standard should be counted and yield none."""
    # Here we configure a max_output_length that is so small that we know that
    # every generated example will exceed the maximum length.
    options = self._create_options_for_length_limit_tests(
        max_output_length_standard=1)
    counters = outputs.GenerationCounters()

    dataset = dataset_generation.generate_dataset(
        options=options, counters=counters, rng=self.rng)

    with self.subTest('should_generate_no_top_level_examples'):
      self.assertEmpty(dataset.to_example_set(),
                       test_utils.get_dataset_summary(counters, dataset))

    with self.subTest('should_count_failed_example_attempts'):
      self.assertGreater(counters.example_attempts.exceeded_max_output_length,
                         0)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_should_not_throw_errors_when_max_input_output_length_compact_is_set(
      self, unused_mock):
    """Setting the compact length limits must not make generation raise."""
    options = self._create_options_for_length_limit_tests(
        max_input_length_compact=1, max_output_length_compact=1)
    counters = outputs.GenerationCounters()

    try:
      unused_dataset = dataset_generation.generate_dataset(
          options=options, counters=counters, rng=self.rng)
    except Exception:  # pylint: disable=broad-except
      # Report any exception as a test failure (with its traceback) rather
      # than as a test error, since "no exception" is the property under test.
      self.fail('Exception raised in generate_dataset:\n'
                f'{traceback.format_exc()}')


if __name__ == '__main__':
  # Run this module's tests through absltest's runner when executed directly.
  absltest.main()
