# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import benchmark_generation as bg
from conceptual_learning.cscan import dataset_spec_loader
from conceptual_learning.cscan import test_utils
from conceptual_learning.cscan.specs import dataset_specs
from conceptual_learning.cscan.specs import dataset_suite_specs


class DatasetSpecLoaderTest(parameterized.TestCase):
  """Tests for `dataset_spec_loader` and consistency of the spec registries.

  Besides exercising the loader functions directly, several tests sweep over
  every entry of `dataset_specs.DATASET_SPEC_BY_ID` and
  `dataset_suite_specs.DATASET_SUITE_SPEC_BY_ID`. They flag broken references
  (invalid dataset spec IDs, grammar templates that fail to load, colliding
  random seeds) and report every offending entry at once rather than stopping
  at the first.
  """

  def setUp(self):
    super().setUp()
    # NOTE(review): no test in this class reads `self.rng`; kept in case
    # subclasses or future tests rely on it.
    self.rng = np.random.RandomState(42)

  @parameterized.named_parameters(
      ('valid_dataset_spec_id', test_utils.TEST_DATASET_SPEC_ID, True),
      ('dataset_suite_spec_id', test_utils.TEST_DATASET_SUITE_SPEC_ID, False),
      ('invalid_id', test_utils.INVALID_DATASET_SPEC_ID, False))
  def test_is_valid_dataset_spec_id(self, spec_id, expected_result):
    """`is_valid_dataset_spec_id` accepts only dataset spec IDs.

    A dataset *suite* spec ID must be rejected just like an unknown ID.
    """
    self.assertEqual(
        dataset_spec_loader.is_valid_dataset_spec_id(spec_id), expected_result)

  def test_load_dataset_spec_succeeds_with_valid_id(self):
    """Loading a known dataset spec returns the spec with that ID."""
    dataset_spec = dataset_spec_loader.load_dataset_spec(
        test_utils.TEST_DATASET_SPEC_ID)
    self.assertEqual(dataset_spec.id, test_utils.TEST_DATASET_SPEC_ID)

  def test_load_dataset_spec_raises_error_with_invalid_id(self):
    """Loading an unknown dataset spec ID raises a descriptive ValueError."""
    with self.assertRaisesRegex(ValueError, 'Invalid dataset spec id'):
      _ = dataset_spec_loader.load_dataset_spec(
          test_utils.INVALID_DATASET_SPEC_ID)

  def test_load_dataset_suite_spec_succeeds_with_valid_id(self):
    """Loading a known dataset suite spec returns the suite with that ID."""
    dataset_suite_spec = dataset_spec_loader.load_dataset_suite_spec(
        test_utils.TEST_DATASET_SUITE_SPEC_ID)
    self.assertEqual(dataset_suite_spec.id,
                     test_utils.TEST_DATASET_SUITE_SPEC_ID)

  def test_load_dataset_suite_spec_raises_error_with_invalid_id(self):
    """Loading an unknown suite spec ID raises a descriptive ValueError."""
    with self.assertRaisesRegex(ValueError, 'Invalid dataset suite spec id'):
      _ = dataset_spec_loader.load_dataset_suite_spec(
          test_utils.INVALID_DATASET_SPEC_ID)

  def test_dataset_suite_specs_contain_only_valid_spec_ids(self):
    """Every dataset spec ID referenced by a suite must be a valid one."""
    invalid_dataset_spec_ids_by_suite_id = collections.defaultdict(list)
    for suite_id, suite_spec in (
        dataset_suite_specs.DATASET_SUITE_SPEC_BY_ID.items()):
      for dataset_spec_id in suite_spec.dataset_specs:
        if not dataset_spec_loader.is_valid_dataset_spec_id(dataset_spec_id):
          invalid_dataset_spec_ids_by_suite_id[suite_id].append(dataset_spec_id)

    self.assertEmpty(
        invalid_dataset_spec_ids_by_suite_id,
        'This is a collection of dataset suites that reference invalid dataset '
        'spec IDs, together with the problematic dataset spec IDs for each. '
        'This collection should have been empty!')

  def test_dataset_specs_contain_only_valid_grammar_templates(self):
    """Every dataset spec that names a grammar template must load it cleanly.

    Specs with an empty `template_grammar_id` are skipped. A `ValueError`
    raised while loading a template is recorded per spec ID so that all
    broken specs are reported together.
    """
    template_error_by_spec_id = {}
    for spec_id, dataset_spec in dataset_specs.DATASET_SPEC_BY_ID.items():
      if not dataset_spec.template_grammar_id:
        continue
      try:
        _ = bg.load_fixed_phrase_structure_grammar_template_for_dataset_spec(
            dataset_spec)
      except ValueError as e:
        template_error_by_spec_id[spec_id] = e

    self.assertEmpty(
        template_error_by_spec_id,
        'This is a collection of dataset specs that have problems with the '
        'grammar template that they reference, together with the specific '
        'error. This collection should have been empty!')

  def test_dataset_spec_random_seeds_are_unique_unless_intended_otherwise(self):
    """No two dataset specs may share a random seed unless declared intended.

    This ensures that different dataset specs always yield materially
    different random content, however similar their GenerationOptions may
    otherwise be. A spec that sets `random_seed_same_as` deliberately shares
    its seed with the referenced spec. It is therefore counted under the
    referenced spec's ID rather than its own, so the pair collapses into one
    entry.
    """
    spec_ids_by_seed = collections.defaultdict(set)
    for spec_id, dataset_spec in dataset_specs.DATASET_SPEC_BY_ID.items():
      options = dataset_spec.generation_options
      # Attribute intentional seed sharing to the spec whose seed is copied.
      owner_spec_id = options.random_seed_same_as or spec_id
      spec_ids_by_seed[options.random_seed].add(owner_spec_id)

    # Report only the seeds that are actually shared. Unlike taking `max()`
    # over all seeds, this does not raise when the registry is empty.
    colliding_spec_ids_by_seed = {
        seed: spec_ids
        for seed, spec_ids in spec_ids_by_seed.items()
        if len(spec_ids) > 1
    }
    self.assertEmpty(
        colliding_spec_ids_by_seed,
        'This is a collection of random seeds that are shared by more than '
        'one dataset spec without `random_seed_same_as` declaring the sharing '
        'as intended, together with the dataset spec IDs sharing each seed. '
        'This collection should have been empty!')


def _main():
  """Runs every absltest test case defined in this module."""
  absltest.main()


if __name__ == '__main__':
  _main()
