# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Superclass of benchmark_generation_small_scale_generation_test_*."""
import dataclasses
import traceback

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import benchmark_generation as bg
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan.specs import dataset_specs


class BenchmarkGenerationSmallScaleGenerationTest(parameterized.TestCase):
  """Base class for benchmark_generation_small_scale_generation tests.

  This class defines no test methods of its own; it only provides a helper
  that runs a small-scale generation over one shard of the registered dataset
  specs and asserts that no spec raised an error or produced the wrong number
  of examples.
  """

  def _do_test_dataset_specs_shows_no_issue_in_small_scale_generation(
      self, shard_index, total_shards):
    """Tests the specified shard of the registered dataset specs.

    E.g., if shard_index = 0 and total_shards = 3, then will test just the
    dataset specs whose index % total_shards == 0, where the index is the
    position of the spec in the list of registered specs sorted by spec id.

    This is a technique to enable this rather time-consuming test to run faster
    via parallelization across multiple test files.

    Failures are collected across all specs of the shard and reported together
    at the end, in two separate subtests: one for specs whose generation raised
    an exception, and one for specs whose generation completed but whose
    counters show an unexpected number of examples. A spec whose generation
    raised is reported only in the first subtest.

    Args:
      shard_index: Zero-based shard index indicating which datasets to test.
      total_shards: Total number of shards to be run including this one.

    Raises:
      ValueError: If total_shards is not positive, or if shard_index is not in
        the range [0, total_shards). Such a shard would match no dataset spec
        and the test would otherwise pass without testing anything.
    """
    if total_shards <= 0 or not 0 <= shard_index < total_shards:
      raise ValueError(
          f'Invalid shard: shard_index={shard_index}, '
          f'total_shards={total_shards}. Expected total_shards > 0 and '
          f'0 <= shard_index < total_shards.')

    requested_num_top_level_examples = 20
    generation_error_by_spec_id = {}
    generation_counter_issues_by_spec_id = {}

    # Sort by spec id so that every shard sees the same spec-to-index mapping.
    sorted_spec_id_and_dataset_spec = sorted(
        dataset_specs.DATASET_SPEC_BY_ID.items(),
        key=lambda id_spec: id_spec[0])
    for index, (spec_id,
                dataset_spec) in enumerate(sorted_spec_id_and_dataset_spec):
      if index % total_shards != shard_index:
        continue

      # Copy the spec with overridden sampling options (a single context with a
      # fixed number of requests) so that the run stays small.
      dataset_spec_for_small_scale_run = dataclasses.replace(
          dataset_spec,
          generation_options=dataclasses.replace(
              dataset_spec.generation_options,
              sampling=dataclasses.replace(
                  dataset_spec.generation_options.sampling,
                  num_contexts=1,
                  num_requests_per_context=requested_num_top_level_examples,
                  derived_production_yield_probability=0.1)))

      try:
        stats = bg.generate_benchmark_from_spec(
            benchmark_dir=self.create_tempdir().full_path,
            should_overwrite=True,
            dataset_spec=dataset_spec_for_small_scale_run)
      except Exception as e:  # pylint: disable=broad-except
        # Deliberately broad: record the failure and keep going so that a
        # single run reports every failing spec of the shard.
        logging.warning('Error in dataset spec: %s', spec_id)
        traceback.print_exc()
        generation_error_by_spec_id[spec_id] = e
        # There are no counters to inspect for a spec that failed to generate;
        # checking them anyway would only add a misleading second report.
        continue

      num_top_level_examples = stats.counters.examples.get_total()
      if num_top_level_examples != requested_num_top_level_examples:
        generation_counter_issues_by_spec_id[spec_id] = (
            f'Generated wrong number of examples. Should have been '
            f'{requested_num_top_level_examples} but was '
            f'{num_top_level_examples}. Full counters: {stats.counters}')

    with self.subTest('should_raise_no_error'):
      self.assertEmpty(
          generation_error_by_spec_id,
          'This is a collections of dataset specs that raised errors during '
          'generation, together with the specific error. This collection '
          'should have been empty!')

    with self.subTest('resulting_counters_should_indicate_success'):
      self.assertEmpty(
          generation_counter_issues_by_spec_id,
          'This is a collections of dataset specs that showed anomalies in the '
          'counters output from generation, together with the specific issue. '
          'This collection should have been empty!')


# Entry point when this module is executed directly: hand control to absltest.
if __name__ == '__main__':
  absltest.main()
