# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import test_utils


class ExpectedNumRulesPerGrammarTest(parameterized.TestCase):
  """Tests for `test_utils.get_expected_num_rules_per_grammar`."""

  def setUp(self):
    super().setUp()
    # Seeded random state that is handed to the GrammarGenerator below.
    self.rng = np.random.RandomState(42)

  @parameterized.named_parameters(
      dict(
          testcase_name='each_primitive_represents_one_rule',
          options=test_utils.create_grammar_options_with_fixed_number_of_rules(
              num_primitives=5,
              num_precedence_levels=0,
              num_functions_per_level=0,
              has_pass_through_rules=False,
              has_concat_rule=False),
          expected_num_rules=5),
      dict(
          testcase_name=(
              'num_functions_per_level_should_apply_to_each_precedence_level'),
          options=test_utils.create_grammar_options_with_fixed_number_of_rules(
              num_primitives=5,
              num_precedence_levels=2,
              num_functions_per_level=3,
              has_pass_through_rules=False,
              has_concat_rule=False),
          expected_num_rules=5 + 2 * 3),
      dict(
          testcase_name='pass_through_rules_apply_to_each_precedence_level',
          options=test_utils.create_grammar_options_with_fixed_number_of_rules(
              num_primitives=5,
              num_precedence_levels=2,
              num_functions_per_level=3,
              has_pass_through_rules=True,
              has_concat_rule=False),
          expected_num_rules=5 + 2 * (3 + 1)),
      dict(
          testcase_name=(
              'should_have_at_most_one_concat_rule_in_the_entire_grammar'),
          options=test_utils.create_grammar_options_with_fixed_number_of_rules(
              num_primitives=5,
              num_precedence_levels=2,
              num_functions_per_level=3,
              has_pass_through_rules=False,
              has_concat_rule=True),
          expected_num_rules=5 + 2 * 3 + 1),
  )
  def test_with_fixed_number_of_rules(self, options, expected_num_rules):
    """Checks the helper's value and the size of a freshly generated grammar.

    Args:
      options: Grammar options built with a fixed number of rules.
      expected_num_rules: Number of rules expected for `options`.
    """
    with self.subTest('should_return_the_expected_value'):
      computed_num_rules = test_utils.get_expected_num_rules_per_grammar(
          options)
      self.assertEqual(expected_num_rules, computed_num_rules)

    with self.subTest(
        'should_match_the_actual_number_of_rules_in_the_constructed_grammar'):
      generator = grammar_generation.GrammarGenerator(
          options=options, rng=self.rng)
      schema = generator.generate_grammar_schema()
      generated_rules = list(schema.get_all_rules())
      self.assertLen(generated_rules, expected_num_rules)

  @parameterized.named_parameters(
      dict(
          testcase_name='unpredictable_number_of_primitives',
          grammar_options=dataclasses.replace(
              test_utils.create_grammar_options_with_fixed_number_of_rules(),
              num_primitives=None,
              min_num_primitives=2,
              max_num_primitives=3)),
      dict(
          testcase_name='unpredictable_number_of_precedence_levels',
          grammar_options=dataclasses.replace(
              test_utils.create_grammar_options_with_fixed_number_of_rules(),
              num_precedence_levels=None,
              min_num_precedence_levels=2,
              max_num_precedence_levels=3)),
      dict(
          testcase_name='unpredictable_number_of_functions_per_level',
          grammar_options=dataclasses.replace(
              test_utils.create_grammar_options_with_fixed_number_of_rules(),
              min_num_functions_per_level=2,
              max_num_functions_per_level=3)),
      dict(
          testcase_name='unpredictable_number_of_pass_through_rules',
          grammar_options=dataclasses.replace(
              test_utils.create_grammar_options_with_fixed_number_of_rules(),
              prob_pass_through_rule=0.5)),
      dict(
          testcase_name='unpredictable_number_of_concat_rules',
          grammar_options=dataclasses.replace(
              test_utils.create_grammar_options_with_fixed_number_of_rules(),
              prob_concat_rule=0.5)),
  )
  def test_should_raise_error_if_number_of_rules_is_not_fixed(
      self, grammar_options):
    """Expects a ValueError when the options do not pin down the rule count.

    Args:
      grammar_options: Fixed-rule-count options with one setting replaced by
        a range or a probability of 0.5.
    """
    expected_message = (
        'Only defined if the number of rules in the grammar is fixed')
    with self.assertRaisesRegex(ValueError, expected_message):
      test_utils.get_expected_num_rules_per_grammar(grammar_options)


class ExpectedNumTopLevelExamplesTest(parameterized.TestCase):
  """Tests for `test_utils.get_expected_num_top_level_examples`."""

  def test_should_include_the_configured_number_of_examples_per_context(self):
    """Two contexts with three requests each should yield six examples."""
    sampling_options = inputs.SamplingOptions(
        num_contexts=2, num_requests_per_context=3)
    self.assertEqual(
        6, test_utils.get_expected_num_top_level_examples(sampling_options))

  @parameterized.named_parameters(
      dict(
          testcase_name='zero_num_context',
          num_contexts=0,
          num_requests_per_context=3,
          expected_num_top_level_examples=0),
      dict(
          testcase_name='zero_num_requests_per_context',
          num_contexts=2,
          num_requests_per_context=0,
          expected_num_top_level_examples=0),
      dict(
          testcase_name='one_input_negative',
          num_contexts=2,
          num_requests_per_context=-3,
          expected_num_top_level_examples=-6),
      dict(
          testcase_name='both_inputs_negative',
          num_contexts=-2,
          num_requests_per_context=-3,
          expected_num_top_level_examples=6),
  )
  def test_returns_reasonable_value_in_edge_cases_of_marginal_interest(
      self, num_contexts, num_requests_per_context,
      expected_num_top_level_examples):
    """Pins down the result for zero and negative sampling settings.

    Args:
      num_contexts: Value passed as `SamplingOptions.num_contexts`.
      num_requests_per_context: Value passed as
        `SamplingOptions.num_requests_per_context`.
      expected_num_top_level_examples: Value the helper is expected to return.
    """
    # Note: None of these inputs corresponds to a foreseen use case. The cases
    # only record what the current implementation happens to return.
    sampling_options = inputs.SamplingOptions(
        num_contexts=num_contexts,
        num_requests_per_context=num_requests_per_context)
    self.assertEqual(
        expected_num_top_level_examples,
        test_utils.get_expected_num_top_level_examples(sampling_options))


# Run the test cases defined in this module through absltest when the file is
# executed directly as a script.
if __name__ == '__main__':
  absltest.main()
