# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for outputting the Conceptual SCAN benchmark."""

import errno
import json
import logging
import os
from typing import Dict, Optional, Union

import dataclasses_json
import tensorflow as tf
import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.util import io_utils


def _write_examples_as_json(examples,
                            file_path,
                            backup_existing = True):
  """Dumps the serialized form of an example set to a JSON file.

  Args:
    examples: Object whose `serialize()` result is JSON-encoded (e.g., a
      GroupedExampleSet).
    file_path: Path of the file to write to.
    backup_existing: Whether to call `io_utils.back_up_existing_file` on
      `file_path` before writing.
  """
  if backup_existing:
    io_utils.back_up_existing_file(file_path)

  logging.info('Writing to file: %s', file_path)
  with tf.io.gfile.GFile(file_path, 'w') as writer:
    # Serialization intentionally happens after the file is opened, matching
    # the established order of side effects.
    serialized_examples = examples.serialize()
    json.dump(serialized_examples, writer, indent=2)


def counters_from_grouped_example_set(
    dataset):
  """Builds generation counters by walking every example group of a dataset.

  For each group in `dataset.example_groups`, the group's context is fed to
  the context and rule counters, each item iterated from the context is fed to
  the context-example counter, and each example iterated from the group is fed
  to the example counter together with the group's context.

  Args:
    dataset: Object exposing `example_groups`, an iterable of example groups.

  Returns:
    A freshly created `outputs.GenerationCounters` with the above counters
    updated.
  """
  result = outputs.GenerationCounters()
  for group in dataset.example_groups:
    context = group.context
    result.contexts.update_with_context(context)
    result.rules.update_with_context(context)
    for example_in_context in context:
      # NOTE(review): only the example is passed here, while the call below
      # also passes the context -- confirm the method's second parameter is
      # optional.
      result.context_examples.update_with_example_and_context(
          example_in_context)
    for top_level_example in group:
      result.examples.update_with_example_and_context(top_level_example,
                                                      context)
  return result


def _write_object_as_json(obj,
                          file_path,
                          backup_existing = True):
  """Writes the JSON representation of an object to the specified file.

  Args:
    obj: Object to write. It must provide a `to_json()` method returning a
      string, as objects implemented with @dataclasses_json.dataclass_json do.
    file_path: Path of the file to write to.
    backup_existing: Whether to call `io_utils.back_up_existing_file` on
      `file_path` before writing.
  """
  if backup_existing:
    io_utils.back_up_existing_file(file_path)

  logging.info('Writing to file: %s', file_path)
  with tf.io.gfile.GFile(file_path, 'w') as writer:
    json_text = obj.to_json()
    writer.write(json_text)


def write_dataset_spec(benchmark_dir, spec):
  """Writes the dataset spec as JSON into the specified directory.

  The directory is created first if it does not exist yet. The destination
  file path comes from `get_dataset_spec_path`, so it always agrees with the
  path that `read_dataset_spec` reads from.

  Args:
    benchmark_dir: Path of the directory in which to write the spec file.
    spec: Spec containing the options used in the generation of the dataset.
      It must provide a `to_json()` method.

  Raises:
    OSError: If creating the directory fails for any reason other than the
      directory already existing.
  """
  if not tf.io.gfile.exists(benchmark_dir):
    logging.info('Directory doesn\'t exist. Creating: %s', benchmark_dir)
    try:
      tf.io.gfile.makedirs(benchmark_dir)
    except OSError as e:
      # Tolerate the directory having appeared between the existence check and
      # the creation attempt; anything else is a genuine failure.
      if e.errno != errno.EEXIST or not tf.io.gfile.isdir(benchmark_dir):
        raise
  _write_object_as_json(spec, get_dataset_spec_path(benchmark_dir))


def get_dataset_path(benchmark_dir):
  """Returns the path of `dataset.json` inside the given directory."""
  dataset_file_name = 'dataset.json'
  return os.path.join(benchmark_dir, dataset_file_name)


def get_dataset_counters_path(benchmark_dir):
  """Returns the path of `dataset_counters.json` inside the given directory."""
  counters_file_name = 'dataset_counters.json'
  return os.path.join(benchmark_dir, counters_file_name)


def get_split_path(benchmark_dir, split_type):
  """Returns the path of the `<split_type>.json` file inside the directory."""
  split_file_name = f'{split_type}.json'
  return os.path.join(benchmark_dir, split_file_name)


def get_split_counters_path(benchmark_dir, split_type):
  """Returns the path of the `<split_type>_counters.json` file in the dir."""
  counters_file_name = f'{split_type}_counters.json'
  return os.path.join(benchmark_dir, counters_file_name)


def get_dataset_spec_path(benchmark_dir):
  """Returns the path of `spec.json` inside the given directory."""
  spec_file_name = 'spec.json'
  return os.path.join(benchmark_dir, spec_file_name)


def get_timing_path(benchmark_dir):
  """Returns the path of `timing.json` inside the given directory."""
  timing_file_name = 'timing.json'
  return os.path.join(benchmark_dir, timing_file_name)


def get_splitting_stats_path(benchmark_dir):
  """Returns the path of `splitting_stats.json` inside the given directory."""
  stats_file_name = 'splitting_stats.json'
  return os.path.join(benchmark_dir, stats_file_name)


def write_dataset(benchmark_dir, dataset):
  """Writes the dataset as JSON to the dataset file of the directory.

  Args:
    benchmark_dir: Path of the directory to which to write the dataset.
    dataset: Dataset to write; it must provide a `serialize()` method.
  """
  dataset_path = get_dataset_path(benchmark_dir)
  _write_examples_as_json(dataset, dataset_path)


def write_counters(benchmark_dir,
                   counters):
  """Writes the counters as JSON to the dataset counters file of the dir."""
  counters_path = get_dataset_counters_path(benchmark_dir)
  _write_object_as_json(counters, counters_path)


def write_timing(benchmark_dir, timing):
  """Writes the timing object as JSON to the timing file of the directory."""
  timing_path = get_timing_path(benchmark_dir)
  _write_object_as_json(timing, timing_path)


def write_splitting_stats(benchmark_dir,
                          splitting_stats):
  """Writes the splitting stats as JSON to the stats file of the directory."""
  stats_path = get_splitting_stats_path(benchmark_dir)
  _write_object_as_json(splitting_stats, stats_path)


def write_splits(benchmark_dir,
                 splits):
  """Writes every split, plus counters derived from it, to the directory.

  For each split, the counters file is written first and the split contents
  second.

  Args:
    benchmark_dir: Path of the directory to which to write the splits.
    splits: Mapping of split ID (e.g., Train, Test) to the split contents. The
      split ID determines the names of the files written for that split.
  """
  for split_type, split_content in splits.items():
    # Counters are recomputed from the split contents rather than passed in.
    split_counters = counters_from_grouped_example_set(split_content)
    counters_path = get_split_counters_path(benchmark_dir, split_type)
    _write_object_as_json(split_counters, counters_path)

    split_path = get_split_path(benchmark_dir, split_type)
    _write_examples_as_json(split_content, split_path)


def read_dataset(benchmark_dir):
  """Reads the dataset file of the directory into a GroupedExampleSet.

  Args:
    benchmark_dir: Path of the directory containing the dataset file.

  Returns:
    The result of `cl.GroupedExampleSet.deserialize` applied to the parsed
    JSON contents of the dataset file.
  """
  dataset_path = get_dataset_path(benchmark_dir)
  with tf.io.gfile.GFile(dataset_path, 'r') as reader:
    raw_json = reader.read()

  parsed_json = json.loads(raw_json)
  return cl.GroupedExampleSet.deserialize(parsed_json)


def read_split(benchmark_dir,
               split_type):
  """Returns a split read from the specified directory or None if not found.

  The file path comes from `get_split_path`, the same helper used when the
  splits are written, so reader and writer cannot drift apart.

  Args:
    benchmark_dir: Path of the directory containing the split files.
    split_type: ID of the split to read; determines the file name.

  Returns:
    The deserialized `cl.GroupedExampleSet`, or None if the split file does
    not exist.
  """
  split_file_path = get_split_path(benchmark_dir, split_type)
  if not tf.io.gfile.exists(split_file_path):
    return None
  # Each stage is logged separately so the progress of reading a split can be
  # followed in the logs.
  logging.info('Reading split file: %s', split_file_path)
  with tf.io.gfile.GFile(split_file_path, 'r') as reader:
    file_contents = reader.read()
  logging.info('Loading JSON: %s', split_file_path)
  json_contents = json.loads(file_contents)
  logging.info('Converting JSON to GroupedExampleSet: %s', split_file_path)
  grouped_example_set = cl.GroupedExampleSet.deserialize(json_contents)
  logging.info('Finished reading split: %s', split_file_path)
  return grouped_example_set


def read_splits(benchmark_dir):
  """Returns the splits found in the specified directory.

  Args:
    benchmark_dir: Path of the directory containing the split files.

  Returns:
    Dict mapping each of the train, test and validation split IDs to its
    contents. A split whose file does not exist is left out of the dict.
  """
  splits = {}
  for split_type in (tfds.Split.TRAIN, tfds.Split.TEST, tfds.Split.VALIDATION):
    split = read_split(benchmark_dir, split_type)
    # `read_split` signals a missing file with None. Compare against None
    # explicitly instead of relying on truthiness, so that a split that exists
    # but evaluates as falsy (e.g., an empty one) is not silently dropped.
    if split is not None:
      splits[split_type] = split
  return splits


def read_dataset_counters(benchmark_dir):
  """Returns the counters stored in the directory, or empty ones if absent.

  Args:
    benchmark_dir: Path of the directory containing the counters file.

  Returns:
    `outputs.GenerationCounters` parsed from the dataset counters file if that
    file exists; otherwise a default-constructed `outputs.GenerationCounters`.
  """
  counters_path = get_dataset_counters_path(benchmark_dir)
  if not tf.io.gfile.exists(counters_path):
    return outputs.GenerationCounters()

  with tf.io.gfile.GFile(counters_path, 'r') as reader:
    serialized_counters = reader.read()
  return outputs.GenerationCounters.from_json(serialized_counters)


def read_dataset_spec(benchmark_dir):
  """Returns the dataset spec object read from benchmark_dir.

  Args:
    benchmark_dir: Path of the directory containing the spec file.

  Returns:
    `inputs.DatasetSpec` parsed from the spec file.

  Raises:
    ValueError: If the spec file does not exist.
  """
  spec_path = get_dataset_spec_path(benchmark_dir)
  if not tf.io.gfile.exists(spec_path):
    raise ValueError(f'File: {spec_path} does not exist.')

  with tf.io.gfile.GFile(spec_path, 'r') as reader:
    serialized_spec = reader.read()
  return inputs.DatasetSpec.from_json(serialized_spec)


def read_timing(benchmark_dir):
  """Returns the timing stored in the directory, or an empty one if absent.

  Args:
    benchmark_dir: Path of the directory containing the timing file.

  Returns:
    `outputs.GenerationTiming` parsed from the timing file if that file
    exists; otherwise a default-constructed `outputs.GenerationTiming`.
  """
  timing_path = get_timing_path(benchmark_dir)
  if not tf.io.gfile.exists(timing_path):
    return outputs.GenerationTiming()

  with tf.io.gfile.GFile(timing_path, 'r') as reader:
    serialized_timing = reader.read()
  return outputs.GenerationTiming.from_json(serialized_timing)


def read_splitting_stats(benchmark_dir):
  """Returns the splitting stats stored in the directory, or empty if absent.

  Args:
    benchmark_dir: Path of the directory containing the splitting stats file.

  Returns:
    `outputs.SplittingStats` parsed from the splitting stats file if that file
    exists; otherwise a default-constructed `outputs.SplittingStats`.
  """
  stats_path = get_splitting_stats_path(benchmark_dir)
  if not tf.io.gfile.exists(stats_path):
    return outputs.SplittingStats()

  with tf.io.gfile.GFile(stats_path, 'r') as reader:
    serialized_stats = reader.read()
  return outputs.SplittingStats.from_json(serialized_stats)


def read_stats(benchmark_dir):
  """Assembles a stats object from the files of the specified directory.

  Args:
    benchmark_dir: Path of the directory containing the stats files.

  Returns:
    `outputs.GenerationStats` combining the counters, timing and splitting
    stats read from the directory.
  """
  # Read in the same order as the keyword arguments below are listed.
  counters = read_dataset_counters(benchmark_dir)
  timing = read_timing(benchmark_dir)
  splitting_stats = read_splitting_stats(benchmark_dir)
  return outputs.GenerationStats(
      counters=counters, timing=timing, splitting_stats=splitting_stats)
