# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_io
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import test_utils


def _create_grouped_example_set():
  """Returns a small GroupedExampleSet holding one example with a context.

  The single top-level example maps request 'e' to reply 'f' and carries a
  frozen context consisting of one example mapping 'c' to 'd'.
  """
  context_examples = [cl.Example(request='c', reply='d')]
  frozen_context = cl.FrozenExampleSet.from_examples(context_examples)
  top_level_examples = [
      cl.Example(context=frozen_context, request='e', reply='f')
  ]
  ungrouped_examples = cl.ExampleSet.from_examples(top_level_examples)
  return cl.GroupedExampleSet.from_example_set(ungrouped_examples)


def _create_stats(generate_dataset_timing=1.0):
  """Builds an outputs.GenerationStats populated with small fixed values.

  Args:
    generate_dataset_timing: Value stored as the `generate_dataset` timing, so
      callers can create several stats objects that are distinguishable once
      written to disk.

  Returns:
    An outputs.GenerationStats whose counters hold fixed non-zero values and
    whose timing holds `generate_dataset_timing`.
  """
  attempt_counters = outputs.ExampleAttemptCounters(duplicate=1, valid=2)
  example_counters = outputs.ExampleCounters(
      by_request_type=collections.defaultdict(
          int, {cl.RequestType.NON_RULE: 3}),
      by_example_type=collections.defaultdict(
          int, {cl.ExampleType.NONRULE_KNOWN_D: 4}),
      by_qualifier=collections.defaultdict(int, {cl.Qualifier.M: 5}))
  error_counters = outputs.GenerationErrorCounters(
      failed_to_illustrate_target_rule=1)
  all_counters = outputs.GenerationCounters(
      example_attempts=attempt_counters,
      examples=example_counters,
      errors=error_counters)
  timing = outputs.GenerationTiming(generate_dataset=generate_dataset_timing)
  return outputs.GenerationStats(counters=all_counters, timing=timing)


class DatasetIoTest(parameterized.TestCase):
  """Tests for the read/write helpers in dataset_io."""

  def assertFileExists(self, filename):
    """Asserts that `filename` exists according to tf.io.gfile."""
    self.assertTrue(
        tf.io.gfile.exists(filename),
        f'File expected to exist but does not: {filename}')

  def test_write_dataset(self):
    """write_dataset should produce a non-empty dataset.json."""
    benchmark_dir = self.create_tempdir().full_path
    dataset = _create_grouped_example_set()

    dataset_io.write_dataset(benchmark_dir=benchmark_dir, dataset=dataset)

    dataset_filepath = os.path.join(benchmark_dir, 'dataset.json')
    dataset_string = test_utils.maybe_read_file(dataset_filepath)
    logging.info('Original dataset:\n%s', dataset)
    logging.info('Dataset from JSON file:\n%s', dataset_string)

    with self.subTest('should_output_full_dataset_in_json_format'):
      self.assertFileExists(dataset_filepath)

    with self.subTest('dataset_file_should_not_be_empty'):
      # Note: The conversion of ExampleSets to JSON is tested in more detail in
      # the tests for ExampleSet.serialize in conceptual_learning_test.py.
      self.assertNotEmpty(dataset_string)

  def test_write_counters(self):
    """write_counters should produce a non-empty dataset_counters.json."""
    benchmark_dir = self.create_tempdir().full_path
    counters = _create_stats().counters

    dataset_io.write_counters(benchmark_dir, counters)

    counters_filepath = os.path.join(benchmark_dir, 'dataset_counters.json')
    counters_string = test_utils.maybe_read_file(counters_filepath)

    logging.info('Original counters:\n%s', counters)
    logging.info('Counters from JSON file:\n%s', counters_string)

    with self.subTest('should_output_counters'):
      self.assertFileExists(counters_filepath)

    with self.subTest('counters_file_should_not_be_empty'):
      self.assertNotEmpty(counters_string)

  def test_write_timing(self):
    """write_timing should produce a non-empty timing.json."""
    benchmark_dir = self.create_tempdir().full_path
    timing = _create_stats().timing

    dataset_io.write_timing(benchmark_dir, timing)

    timing_filepath = os.path.join(benchmark_dir, 'timing.json')
    timing_string = test_utils.maybe_read_file(timing_filepath)
    logging.info('Original timing:\n%s', timing)
    logging.info('Timing from JSON file:\n%s', timing_string)

    with self.subTest('should_output_timing'):
      self.assertFileExists(timing_filepath)

    with self.subTest('timing_file_should_not_be_empty'):
      self.assertNotEmpty(timing_string)

  def test_write_stats_multiple_times(self):
    """Repeated writes should keep numbered backups of overwritten files.

    After three writes with timings 1.0, 2.0 and 3.0, the expected layout is:
    timing.json holds the newest content (3.0), timing_1.json the oldest (1.0)
    and timing_2.json the second oldest (2.0).
    """
    benchmark_dir = self.create_tempdir().full_path
    stats1 = _create_stats(generate_dataset_timing=1.0)
    stats2 = _create_stats(generate_dataset_timing=2.0)
    stats3 = _create_stats(generate_dataset_timing=3.0)

    dataset_io.write_counters(benchmark_dir, stats1.counters)
    dataset_io.write_counters(benchmark_dir, stats2.counters)
    dataset_io.write_counters(benchmark_dir, stats3.counters)
    dataset_io.write_timing(benchmark_dir, stats1.timing)
    dataset_io.write_timing(benchmark_dir, stats2.timing)
    dataset_io.write_timing(benchmark_dir, stats3.timing)

    benchmark_dir_content = tf.io.gfile.listdir(benchmark_dir)
    logging.info('benchmark_dir content: %s', benchmark_dir_content)

    timing_string = test_utils.maybe_read_file(
        os.path.join(benchmark_dir, 'timing.json'))
    timing_string_1 = test_utils.maybe_read_file(
        os.path.join(benchmark_dir, 'timing_1.json'))
    timing_string_2 = test_utils.maybe_read_file(
        os.path.join(benchmark_dir, 'timing_2.json'))

    with self.subTest('should_leave_backup_copies_of_overwritten_files'):
      self.assertCountEqual(('dataset_counters.json', 'dataset_counters_1.json',
                             'dataset_counters_2.json', 'timing.json',
                             'timing_1.json', 'timing_2.json'),
                            benchmark_dir_content)

    # The main file is expected to hold the most recent write (timing 3.0).
    with self.subTest('main_file_should_contain_the_newest_content'):
      self.assertContainsExactSubsequence(timing_string,
                                          '"generate_dataset": 3.0')

    with self.subTest('first_backup_file_should_contain_the_oldest_content'):
      self.assertContainsExactSubsequence(timing_string_1,
                                          '"generate_dataset": 1.0')

    with self.subTest(
        'second_backup_file_should_contain_the_second_oldest_content'):
      self.assertContainsExactSubsequence(timing_string_2,
                                          '"generate_dataset": 2.0')

  def test_write_splits(self):
    """write_splits should write each split plus a per-split counters file."""
    benchmark_dir = self.create_tempdir().full_path
    splits = {
        tfds.Split.TRAIN: _create_grouped_example_set(),
        tfds.Split.TEST: cl.GroupedExampleSet()
    }

    dataset_io.write_splits(benchmark_dir=benchmark_dir, splits=splits)

    train_counters_filepath = os.path.join(benchmark_dir, 'train_counters.json')
    test_counters_filepath = os.path.join(benchmark_dir, 'test_counters.json')
    train_counters_string = test_utils.maybe_read_file(train_counters_filepath)
    test_counters_string = test_utils.maybe_read_file(test_counters_filepath)

    with self.subTest('should_output_train_test_splits'):
      self.assertFileExists(os.path.join(benchmark_dir, 'train.json'))
      self.assertFileExists(os.path.join(benchmark_dir, 'test.json'))

    with self.subTest('should_output_counters'):
      self.assertFileExists(train_counters_filepath)
      self.assertFileExists(test_counters_filepath)

    with self.subTest('counters_files_should_not_be_empty'):
      self.assertNotEmpty(train_counters_string)
      self.assertNotEmpty(test_counters_string)

  def test_write_read_and_rewrite_splits(self):
    """Splits should round-trip, and rewriting them should leave backups."""
    benchmark_dir = self.create_tempdir().full_path
    splits = {
        tfds.Split.TRAIN: _create_grouped_example_set(),
        tfds.Split.TEST: cl.GroupedExampleSet()
    }

    dataset_io.write_splits(benchmark_dir=benchmark_dir, splits=splits)
    restored_splits = dataset_io.read_splits(benchmark_dir)

    logging.info('Original splits:\n%s', splits)
    logging.info('Splits from JSON file:\n%s', restored_splits)

    # Writing a second time into the same directory should back up every file
    # from the first write with a `_1` suffix.
    dataset_io.write_splits(benchmark_dir=benchmark_dir, splits=restored_splits)

    benchmark_dir_content = tf.io.gfile.listdir(benchmark_dir)
    logging.info('benchmark_dir content: %s', benchmark_dir_content)

    with self.subTest('write_read_roundtrip_should_restore_original_contents'):
      self.assertEqual(splits, restored_splits)

    with self.subTest('should_leave_backup_copies_of_overwritten_files'):
      self.assertCountEqual(
          ('train.json', 'train_counters.json', 'test.json',
           'test_counters.json', 'train_1.json', 'train_counters_1.json',
           'test_1.json', 'test_counters_1.json'), benchmark_dir_content)

  @parameterized.named_parameters(('benchmark_dir_exists', True),
                                  ('benchmark_dir_does_not_yet_exist', False))
  def test_write_read_dataset_spec(self, benchmark_dir_already_exists):
    """DatasetSpec should round-trip whether or not the directory exists.

    Args:
      benchmark_dir_already_exists: If True, write into an existing temp dir;
        otherwise write into a not-yet-created subdirectory of one.
    """
    # Both branches produce a plain `str` path, matching the other tests, so
    # the only difference between the two cases is whether the dir exists.
    if benchmark_dir_already_exists:
      benchmark_dir = self.create_tempdir().full_path
    else:
      benchmark_dir = os.path.join(self.create_tempdir().full_path, 'subdir')
    spec = inputs.DatasetSpec(id='test', description='Test dataset spec.')

    dataset_io.write_dataset_spec(benchmark_dir, spec)

    spec_filepath = os.path.join(benchmark_dir, 'spec.json')
    spec_string = test_utils.maybe_read_file(spec_filepath)
    logging.info('Original spec:\n%s', spec)
    logging.info('Spec from JSON file:\n%s', spec_string)

    with self.subTest('should_create_directory_if_necessary'):
      self.assertFileExists(spec_filepath)

    with self.subTest('spec_file_should_not_be_empty'):
      self.assertNotEmpty(spec_string)

    restored_spec = dataset_io.read_dataset_spec(benchmark_dir)
    with self.subTest('write_read_roundtrip_should_restore_original_content'):
      self.assertEqual(spec, restored_spec)


if __name__ == '__main__':
  absltest.main()
