# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Command line tool to generate a single Conceptual SCAN benchmark.

"""

from absl import app
from absl import flags

from conceptual_learning.cscan import benchmark_generation
from conceptual_learning.cscan import directory_names


# Parent directory for auto-named benchmark directories. main() only consults
# it when --benchmark_dir is unset.
_BENCHMARK_SUITE_DIR = flags.DEFINE_string(
    'benchmark_base_dir',
    default=None,
    help=('Directory under which a benchmark directory will be created if '
          'benchmark_dir is not directly specified. '
          'Ignored if benchmark_dir is specified.'),
)

# Explicit output directory for the benchmark. When set, it takes precedence
# over --benchmark_base_dir.
_BENCHMARK_DIR = flags.DEFINE_string(
    'benchmark_dir',
    default=None,
    # The fallback flag is named `benchmark_base_dir`; the help text previously
    # referred to a non-existent `benchmark_suite_dir` flag.
    help=('Full path to the directory to which the benchmark will be written. '
          'If not specified, then a benchmark directory will automatically be '
          'created under benchmark_base_dir, named with the dataset spec ID '
          'and an index number.'))

# Forwarded to generate_benchmark() as `should_overwrite`; defaults to on.
_OVERWRITE = flags.DEFINE_bool(
    'overwrite',
    help='Whether to overwrite files that already exist.',
    default=True,
)


# Identifier of the DatasetSpec driving generation. It has no default, so the
# user must supply it on the command line.
_DATASET_SPEC = flags.DEFINE_string(
    'dataset_spec',
    default=None,
    help=('Id of the DatasetSpec to use. '
          'Will automatically search for a spec with the given id in '
          '`data/dataset_specs.json`.'),
)

# Which random replica of the DatasetSpec to generate. 0 is a special value
# meaning "use the spec's random_seed unchanged"; negative values are rejected
# by lower_bound.
_REPLICA_INDEX = flags.DEFINE_integer(
    'replica_index',
    default=1,
    # Fixed a stray comma in the help text ("If, the replica_index ...").
    help=('Dataset replica index (starting from 1), in the case where multiple '
          'different random datasets are to be generated for the same spec. '
          'In that case, a separate random seed will be generated for each '
          'replica based on the replica_index and the original random_seed '
          'specified in the DatasetSpec. If the replica_index is set to 0, '
          'then the original random_seed specified in the DatasetSpec will be '
          'used as-is (useful mainly for debugging purposes).'),
    lower_bound=0)

# Trades reproducibility for robustness: when disabled, the seed is chosen at
# run time rather than taken from the DatasetSpec.
_DETERMINISTIC = flags.DEFINE_bool(
    'deterministic',
    default=True,
    help=('If True, then will use the fixed random seed defined in the '
          'DatasetSpec, so as to guarantee that repeated generation runs '
          'with the same spec will yield the same results. '
          'If False, then will use a dynamically-chosen random seed based '
          'on the timestamp, which sacrifices reproducibility in order to '
          'avoid Apache Beam stragglers that can occur when a given random '
          'seed leads to a problematic grammar/context.'),
)


@flags.multi_flags_validator(
    ['benchmark_base_dir', 'benchmark_dir'],
    message=('At least one of benchmark_base_dir or benchmark_dir '
             'needs to be defined.'))
def CheckBenchmarkDir(flags_dict):
  """Checks that an output location can be determined from the flags.

  Args:
    flags_dict: Mapping from flag name to value for the two flags listed in
      the decorator.

  Returns:
    True if either flag holds a truthy value, otherwise False.
  """
  base_dir = flags_dict['benchmark_base_dir']
  explicit_dir = flags_dict['benchmark_dir']
  return any((base_dir, explicit_dir))


def main(argv):
  """Generates one Conceptual SCAN benchmark as configured by the flags.

  Args:
    argv: Command-line arguments remaining after absl flag parsing; only the
      program name itself is expected.

  Raises:
    app.UsageError: If unexpected positional arguments are given, or if
      --dataset_spec was not provided.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Calling flags.mark_flags_as_required() here would have no effect. It only
  # registers a validator, and absl runs validators while parsing flags (or
  # when a flag is later reassigned). app.run() has already parsed the flags
  # before main() is invoked, so enforce the requirement explicitly instead.
  dataset_spec_id = _DATASET_SPEC.value
  if not dataset_spec_id:
    raise app.UsageError('Flag --dataset_spec must be specified.')

  replica_index = _REPLICA_INDEX.value
  # An explicit --benchmark_dir wins; otherwise derive a directory under
  # --benchmark_base_dir. CheckBenchmarkDir ensures at least one is set.
  benchmark_dir = (
      _BENCHMARK_DIR.value or directory_names.get_default_benchmark_directory(
          _BENCHMARK_SUITE_DIR.value, dataset_spec_id, replica_index))
  # Benchmark generation can optionally be parallelized for faster run time by
  # specifying here an appropriate Apache Beam runner.
  beam_runner = None
  benchmark_generation.generate_benchmark(
      benchmark_dir=benchmark_dir,
      should_overwrite=_OVERWRITE.value,
      dataset_spec_id=dataset_spec_id,
      replica_index=replica_index,
      beam_runner=beam_runner,
      deterministic=_DETERMINISTIC.value,
      enable_remote_dependencies=True)


if __name__ == '__main__':
  # absl parses the command-line flags, then calls main() with the leftover
  # positional arguments.
  app.run(main=main)
