# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import traceback
from typing import Iterable, Mapping

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_spec_loader
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import induction
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import test_utils


class SamplingOptionsTest(parameterized.TestCase):
  """Tests for sum-preserving rounding and the per-context request schedule."""

  def test_round_while_preserving_sum_returns_int_type(self):
    rounded = inputs._round_while_preserving_sum([0.67, 0.67, 0.66])
    first_element_type = type(rounded[0])
    # Must be the builtin int, rather than 'numpy.int64', 'float', etc.
    self.assertEqual(first_element_type, int)

  def test_round_while_preserving_sum_adjust_down(self):
    fractions = [0.67, 0.67, 0.66]
    # Plain rounding would yield [1, 1, 1]; the expected output keeps the
    # total at 2.
    expected = [1, 0, 1]
    self.assertEqual(inputs._round_while_preserving_sum(fractions), expected)

  def test_round_while_preserving_sum_adjust_up(self):
    fractions = [0.33, 0.33, 0.34]
    # Plain rounding would yield [0, 0, 0]; the expected output keeps the
    # total at 1.
    expected = [0, 1, 0]
    self.assertEqual(inputs._round_while_preserving_sum(fractions), expected)

  def test_round_while_preserving_sum_total_rounded_down(self):
    fractions = [1.33, 1.0, 1.0]
    # The inputs sum to 3.33, so the total is rounded down to 3.
    expected = [1, 1, 1]
    self.assertEqual(inputs._round_while_preserving_sum(fractions), expected)

  def test_round_while_preserving_sum_total_rounded_up(self):
    fractions = [1.33, 1.33, 1.0]
    # The inputs sum to 3.66, so the total is rounded up to 4.
    expected = [1, 2, 1]
    self.assertEqual(inputs._round_while_preserving_sum(fractions), expected)

  # Each case is: (test name, num_contexts, num_requests_per_context,
  # non_rule_fraction, expected per-context counts by request type).
  @parameterized.named_parameters(
      ('single_context_one_non_rule', 1, 1, 0.6, [
          {cl.RequestType.NON_RULE: 1, cl.RequestType.RULE: 0},
      ]),
      ('single_context_standard_rounding', 1, 20, 0.5, [
          {cl.RequestType.NON_RULE: 10, cl.RequestType.RULE: 10},
      ]),
      ('single_context_adjusted_rounding', 1, 9, 0.5, [
          {cl.RequestType.NON_RULE: 4, cl.RequestType.RULE: 5},
      ]),
      ('multiple_contexts_adjusted_rounding', 2, 9, 0.5, [
          {cl.RequestType.NON_RULE: 4, cl.RequestType.RULE: 5},
          {cl.RequestType.NON_RULE: 5, cl.RequestType.RULE: 4},
      ]),
  )
  def test_calculate_schedule_of_examples_by_type(
      self, num_contexts, num_requests_per_context, non_rule_fraction,
      expected_result):
    """Compares the first `num_contexts` schedule entries to the expectation."""
    sampling_options = inputs.SamplingOptions(
        num_contexts=num_contexts,
        num_requests_per_context=num_requests_per_context,
        non_rule_fraction=non_rule_fraction)
    schedule = sampling_options.calculate_schedule_of_examples_by_type()
    # Only the leading `num_contexts` entries of the schedule are inspected.
    leading_entries = list(itertools.islice(schedule, num_contexts))
    self.assertEqual(leading_entries, expected_result)


class GenerationOptionsTest(absltest.TestCase):
  """Tests for the JSON representation of the option classes in `inputs`."""

  def test_is_json_serializable(self):
    """Default GenerationOptions must convert to JSON without a TypeError."""
    default_options = inputs.GenerationOptions()
    try:
      serialized = default_options.to_json()
    except TypeError:
      # Capture the traceback here, while the exception is still active.
      failure_details = traceback.format_exc()
      self.fail('Exception raised when converting GenerationOptions to JSON: '
                f'{failure_details}')
    else:
      logging.info('Original GenerationOptions:\n%s', default_options)
      logging.info('GenerationOptions as JSON:\n%s', serialized)

  def test_rule_format_is_a_human_readable_string_in_json(self):
    """The rule format appears in the JSON under its enum member name."""
    expected_fragment = '"rule_format": "INTERPRETATION_RULE"'
    sampling_options = inputs.SamplingOptions(
        rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    serialized = sampling_options.to_json()
    self.assertRegex(serialized, expected_fragment)

  def test_split_by_is_a_human_readable_string_in_json(self):
    """The split-by setting appears in the JSON under its enum member name."""
    expected_fragment = '"split_by": "CONTEXT"'
    split_options = inputs.SplitOptions(split_by=inputs.SplitBy.CONTEXT)
    serialized = split_options.to_json()
    self.assertRegex(serialized, expected_fragment)


class DatasetSpecTest(absltest.TestCase):
  """Tests that a DatasetSpec survives a round trip through JSON."""

  def _assert_json_roundtrip(self, original_spec):
    """Serializes the spec, parses it back, logs all forms, and compares.

    Args:
      original_spec: The DatasetSpec to serialize with `to_json` and recover
        with `DatasetSpec.from_json`.
    """
    spec_json = original_spec.to_json()
    restored_spec = inputs.DatasetSpec.from_json(spec_json)
    logging.info('Original DatasetSpec:\n%s', original_spec)
    logging.info('DatasetSpec as JSON:\n%s', spec_json)
    logging.info('Recovered DatasetSpec:\n%s', restored_spec)

    self.assertEqual(restored_spec, original_spec)

  def test_json_serialization_roundtrip(self):
    """Round-trips the spec loaded for the shared test dataset spec id."""
    loaded_spec = dataset_spec_loader.load_dataset_spec(
        test_utils.TEST_DATASET_SPEC_ID)
    self._assert_json_roundtrip(loaded_spec)

  def test_json_serialization_roundtrip_with_non_default_inductive_bias(self):
    """Round-trips an ad-hoc spec with an explicitly configured inductive bias."""
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=5)
    sampling_options = inputs.SamplingOptions(inductive_bias=inductive_bias)
    generation_options = inputs.GenerationOptions(sampling=sampling_options)
    ad_hoc_spec = inputs.DatasetSpec(
        id='test',
        description='Ad-hoc DatasetSpec for serialization test.',
        generation_options=generation_options)
    self._assert_json_roundtrip(ad_hoc_spec)


# Hand control to absltest's entry point when this file is run as a script.
if __name__ == '__main__':
  absltest.main()
