# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for generating the Conceptual SCAN benchmark.

"""

import collections
import copy
import dataclasses
import functools
import itertools
import json
import logging
import time
from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping,
                    Optional, Sequence, Set, Tuple, TypeVar, Union)

import apache_beam as beam
from language.compgen.nqg.tasks import mcd_utils
import numpy as np
import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import dataset_io
from conceptual_learning.cscan import dataset_spec_loader
from conceptual_learning.cscan import divergence_maximization
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import grammar_representation
from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import production_composition
from conceptual_learning.cscan import production_trees
from conceptual_learning.cscan import rule_conversion
from conceptual_learning.cscan import sampling
from conceptual_learning.cscan import similarity_metadata

# A grouped example set paired with the generation counters that go with it.
DatasetAndCounters = Tuple[cl.GroupedExampleSet, outputs.GenerationCounters]
# Mapping from split ID to the grouped example set holding that split.
Splits = Dict[tfds.Split, cl.GroupedExampleSet]
# The splits of a dataset paired with the matching outputs.SplittingStats.
SplitsAndSplittingStats = Tuple[Splits, outputs.SplittingStats]
# Tuple (example_group, dataset_spec, counters, num_examples_by_type).
ExampleGroupWithSpecAndCounters = Tuple[cl.ExampleGroup, inputs.DatasetSpec,
                                        outputs.GenerationCounters,
                                        Mapping[cl.RequestType, int]]


def _generate_dataset_and_counters(
    original_dataset_spec,
    enable_remote_dependencies):
  """Generates one dataset, plus its generation counters, from a dataset spec.

  The spec is deep-copied first, so everything below works on the copy rather
  than on the caller's object.

  Args:
    original_dataset_spec: Spec describing the dataset to generate.
    enable_remote_dependencies: Forwarded unchanged to
      dataset_generation.generate_dataset.

  Returns:
    Tuple (dataset, counters), where counters is the GenerationCounters object
    that was handed to the generator.
  """
  spec = copy.deepcopy(original_dataset_spec)
  logging.info('Generating dataset with spec: %s', spec)

  options = spec.generation_options
  seed = options.random_seed_or_timestamp()
  rng = np.random.RandomState(seed)
  counters = outputs.GenerationCounters()
  grammar_template = (
      load_fixed_phrase_structure_grammar_template_for_dataset_spec(spec))

  dataset = dataset_generation.generate_dataset(
      options,
      counters,
      rng,
      grammar_template,
      enable_remote_dependencies=enable_remote_dependencies)
  logging.info('Completed dataset generation. Counters: %s', counters)
  return dataset, counters


def _combine_datasets_and_counters(
    datasets_and_counters):
  """Merges (dataset, counters) pairs into a single (dataset, counters) pair.

  The example groups of every incoming dataset are appended, in iteration
  order, to one fresh GroupedExampleSet, and the counters are accumulated with
  `+=` into one fresh GenerationCounters.

  No de-duplication is performed, so the merged dataset may end up holding
  several distinct ExampleGroups whose contexts are identical.

  Args:
    datasets_and_counters: Iterable of (dataset, counters) pairs to merge.

  Returns:
    Tuple (merged_dataset, merged_counters).
  """
  merged_dataset = cl.GroupedExampleSet()
  merged_counters = outputs.GenerationCounters()
  for partial_dataset, partial_counters in datasets_and_counters:
    merged_dataset.example_groups.extend(partial_dataset.example_groups)
    merged_counters += partial_counters
  return merged_dataset, merged_counters


def _extract_outputs(output_pair):
  """Wraps the two members of a pair as Beam tagged outputs.

  Element 0 is emitted under tag '0' and element 1 under tag '1', so that a
  downstream `.with_outputs('0', '1')` can separate them.
  """
  return [
      beam.pvalue.TaggedOutput(str(position), output_pair[position])
      for position in (0, 1)
  ]


def generate_and_write_dataset_pipeline(
    benchmark_dir, dataset_spec,
    enable_remote_dependencies):
  """Builds the Beam pipeline that generates the dataset and writes it out.

  One single-context dataset spec is prepared per requested context so that the
  contexts can be generated in parallel. Unless the generation options ask for
  a timestamp-based seed, each of those specs receives its own seed drawn from
  an rng that is itself seeded with the spec's random_seed; otherwise every
  spec keeps the original random_seed value.

  Args:
    benchmark_dir: Directory from which the dataset and counters output paths
      are derived.
    dataset_spec: Spec of the dataset to generate.
    enable_remote_dependencies: Forwarded to _generate_dataset_and_counters.

  Returns:
    A function taking the pipeline root and attaching all transforms to it.
  """
  generation_options = dataset_spec.generation_options
  random_seed = generation_options.random_seed
  rng = np.random.RandomState(random_seed)

  single_context_dataset_specs = []
  for _ in range(generation_options.sampling.num_contexts):
    # RandomState seed can be 'any integer between 0 and 2**32 - 1 inclusive'.
    # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html
    if not generation_options.use_timestamp_for_random_seed():
      random_seed = rng.randint(0, 2**32)
    one_context_sampling = dataclasses.replace(
        generation_options.sampling, num_contexts=1)
    one_context_options = dataclasses.replace(
        generation_options,
        random_seed=random_seed,
        sampling=one_context_sampling)
    single_context_dataset_specs.append(
        dataclasses.replace(
            dataset_spec, generation_options=one_context_options))

  def pipeline(root):
    generate_fn = functools.partial(
        _generate_dataset_and_counters,
        enable_remote_dependencies=enable_remote_dependencies)

    specs = root | beam.Create(single_context_dataset_specs)
    per_context_results = (
        specs | 'GenerateDatasetForEachContext' >> beam.Map(generate_fn))
    combined = (
        per_context_results
        | 'CombineDatasetsAndCounters' >>
        beam.CombineGlobally(_combine_datasets_and_counters))
    tagged = (
        combined
        | 'SeparateDatasetFromCounters' >>
        beam.FlatMap(_extract_outputs).with_outputs('0', '1'))
    dataset = tagged['0']
    counters = tagged['1']

    # The dataset is deep-copied before serialize() is called on it.
    dataset_json = dataset | beam.Map(
        lambda x: json.dumps(copy.deepcopy(x).serialize(), indent=2))
    _ = dataset_json | 'WriteDataset' >> beam.io.WriteToText(
        dataset_io.get_dataset_path(benchmark_dir), shard_name_template='')

    counters_json = counters | beam.Map(
        lambda x: x.to_json())
    _ = counters_json | 'WriteDatasetCounters' >> beam.io.WriteToText(
        dataset_io.get_dataset_counters_path(benchmark_dir),
        shard_name_template='')

  return pipeline


def _split_dataset_and_compute_splitting_stats(
    splitting_option, benchmark_dir,
    rng):
  """Loads the dataset stored under benchmark_dir and splits it.

  Args:
    splitting_option: Passed through to split_dataset_and_compute_stats.
    benchmark_dir: Directory the dataset is read from.
    rng: Random number generator, passed through unchanged.

  Returns:
    Whatever split_dataset_and_compute_stats returns for the loaded dataset.
  """
  return split_dataset_and_compute_stats(
      dataset_io.read_dataset(benchmark_dir), splitting_option, rng)


def split_dataset_pipeline(
    dataset_spec,
    benchmark_dir):
  """Builds the Beam pipeline that splits the dataset and writes the results.

  The pipeline is fed a single element, the spec's splitting options. That
  element is turned into a (splits, splitting_stats) pair, and each half of the
  pair is then written under benchmark_dir by its own DoFn.

  Args:
    dataset_spec: Spec providing the splitting options and the random seed.
    benchmark_dir: Directory the dataset is read from and the splits and
      splitting stats are written to.

  Returns:
    A function taking the pipeline root and attaching all transforms to it.
  """
  seed = dataset_spec.generation_options.random_seed_or_timestamp()
  # A single rng is bound into the split function when the pipeline is built.
  split_fn = functools.partial(
      _split_dataset_and_compute_splitting_stats,
      benchmark_dir=benchmark_dir,
      rng=np.random.RandomState(seed))
  splitting_options = [dataset_spec.generation_options.splitting]

  class WriteSplits(beam.DoFn):
    """Writes each incoming element with dataset_io.write_splits."""

    def __init__(self, benchmark_dir):
      self.benchmark_dir = benchmark_dir

    def process(self, element):
      dataset_io.write_splits(self.benchmark_dir, element)

  class WriteSplittingStats(beam.DoFn):
    """Writes each incoming element with dataset_io.write_splitting_stats."""

    def __init__(self, benchmark_dir):
      self.benchmark_dir = benchmark_dir

    def process(self, element):
      dataset_io.write_splitting_stats(self.benchmark_dir, element)

  def pipeline(root):
    options_to_split = root | beam.Create(splitting_options)
    split_results = options_to_split | 'SplitDataset' >> beam.Map(split_fn)
    tagged = (
        split_results
        | 'SeparateSplitsFromStats' >>
        beam.FlatMap(_extract_outputs).with_outputs('0', '1'))

    splits = tagged['0']
    splitting_stats = tagged['1']

    _ = splits | 'WriteSplits' >> beam.ParDo(WriteSplits(benchmark_dir))
    _ = splitting_stats | 'WriteSplittingStats' >> beam.ParDo(
        WriteSplittingStats(benchmark_dir))

  return pipeline


def _populate_a_single_example_group_with_more_examples(
    example_group_with_specs_and_counters: ExampleGroupWithSpecAndCounters,
    enable_remote_dependencies: bool = False,
    allow_input_mutation: bool = True) -> DatasetAndCounters:
  # pyformat: disable
  """Populates a single example group with more examples.

  The inference engine for the group's context is rebuilt from the context
  metadata (pass through rules, explicit rules, hidden rules and the context
  examples themselves) and then handed to
  dataset_generation.populate_example_group_with_top_level_examples.

  Args:
    example_group_with_specs_and_counters: Tuple containing the example group
      and the necessary arguments to generate more examples for the example
      group. The tuple consists of (example_group, dataset_spec, counters,
      num_examples_by_type) where
        - example_group: the example group we want to append new examples to.
          If allow_input_mutation is False then a deep copy of example_group
          is populated instead of the object passed in.
        - dataset_spec: Dataset spec to use for generating more examples. Its
          generation options also provide the random seed.
        - counters: Generation counters for the example group.
        - num_examples_by_type: Dictionary representing the number of examples
          per each request type.
    enable_remote_dependencies: Whether to allow the generation process to
      access non-local resources. NOTE(review): confirm which value unit tests
      are expected to pass here.
    allow_input_mutation: Allow mutation of the input. If this function is being
      run within a beam pipeline then we can't modify the input by adding more
      examples to the ExampleGroup object. When False, all four elements of the
      input tuple are deep-copied before use.

  Returns:
    Tuple (dataset, counters), where dataset is a cl.GroupedExampleSet holding
    the single populated example group and counters is the counters object that
    was passed to the population step.
  """
  # pyformat: enable
  example_group, dataset_spec, counters, num_examples_by_type = (
      example_group_with_specs_and_counters)
  if not allow_input_mutation:
    # Only modify a copy of the input.
    example_group = copy.deepcopy(example_group)
    dataset_spec = copy.deepcopy(dataset_spec)
    counters = copy.deepcopy(counters)
    num_examples_by_type = copy.deepcopy(num_examples_by_type)
  # The same rng is shared by the grammar generator, the example generator and
  # the final population step below.
  rng = np.random.RandomState(
      dataset_spec.generation_options.random_seed_or_timestamp())
  # Reconstruct the inference engine.
  if (dataset_spec.generation_options.sampling.rule_format
     ) == enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION:
    # When using feature grammar rule format, pass through rules are provided
    # explicitly by the grammar and can even be selected to be unreliable rules
    # and therefore we shouldn't provide it explicitly in this case.
    pass_through_rules = []
  else:
    # Load the pass through rules from the dataset spec.
    phrase_structure_grammar = (
        load_fixed_phrase_structure_grammar_template_for_dataset_spec(
            dataset_spec))
    pass_through_rules = []
    if phrase_structure_grammar is not None:
      for rule in phrase_structure_grammar.get_all_rules():
        rule_string = rule.to_rule_string()
        rule_production = nltk_utils.production_from_production_string(
            rule_string)
        if nltk_utils.is_pass_through_rule(rule_production):
          pass_through_rules.append(rule_production)
  # Create a grammar generator for the simple example generator.
  grammar_generator = grammar_generation.GrammarGenerator(
      options=dataset_spec.generation_options.grammar, rng=rng)
  simple_example_generator = sampling.ExampleGenerator(
      grammar=example_group.context.metadata.grammar,
      rng=rng,
      grammar_generator=grammar_generator,
      options=dataset_spec.generation_options.sampling)
  # The engine is given the example generator's provenance_by_production
  # mapping, which is also the mapping filled in by step 4 below.
  inference_engine = inference.InferenceEngine(
      provenance_by_production=simple_example_generator.provenance_by_production
  )
  # 1) Add the pass through rules to the inference engine as monotonic examples.
  for pass_through_rule in pass_through_rules:
    inference_engine.add_production(pass_through_rule, is_monotonic=True)

  # 2) Add explicit rules to the inference engine as monotonic examples.
  for rule in example_group.context.metadata.explicit_rules:
    for production in simple_example_generator.get_productions_from_rule(rule):
      inference_engine.add_production(production, is_monotonic=True)
  # 3) Add hidden rules as defeasible examples.
  for rule in example_group.context.metadata.hidden_rules:
    for production in simple_example_generator.get_productions_from_rule(rule):
      inference_engine.add_production(production)
  # 4) Add context examples to the inference engine.
  for example in example_group.context:
    if example.metadata.production not in inference_engine.all_productions:
      # If the production is not in the inference engine then that means
      # one of its source rules is sampled from an alternative grammar.
      # Here we simply rebuild that production to create the necessary
      # provenance_by_production mappings.
      # The composed `production` itself is not used after the loop; the
      # example's own production is what gets added to the engine below.
      example_provenance = example.metadata.production_provenance
      production = example_provenance.source
      for other_production, idx in example_provenance.compositions:
        production = production_composition.compose(
            production,
            other_production,
            idx,
            provenance_by_production=(
                simple_example_generator.provenance_by_production))

    inference_engine.add_production(
        example.metadata.production,
        is_monotonic=example.qualifier == cl.Qualifier.M)
  if example_group.context.metadata.hidden_unknown_rules:
    # Adjust the inference engine to only induce hidden true rules and ignore
    # productions with hidden unknown source rules.
    # Note that we didn't do this from the beginning in order to allow
    # provenance_by_production to include the provenance of all productions.
    inference_engine_copy = inference_engine.copy_monotonic_engine()
    for production in inference_engine.source_productions:
      rule = rule_conversion.rule_from_production(
          production, example_group.context.metadata.rule_format)
      if rule in example_group.context.metadata.hidden_true_rules:
        inference_engine_copy.add_production(production, is_monotonic=False)
    inference_engine = inference_engine_copy

  dataset_generation.populate_example_group_with_top_level_examples(
      example_group,
      num_examples_by_type=num_examples_by_type,
      simple_example_generator=simple_example_generator,
      options=dataset_spec.generation_options.sampling,
      inference_engine=inference_engine,
      counters=counters,
      rng=rng,
      enable_remote_dependencies=enable_remote_dependencies)
  return cl.GroupedExampleSet(example_groups=[example_group]), counters


def populate_dataset_with_more_examples_per_context_pipeline(
    benchmark_dir,
    dataset,
    split,
    original_dataset_spec,
    enable_remote_dependencies = True
):
  """Builds the Beam pipeline that adds more examples to each example group.

  Each example group of `dataset` is paired with one entry of a per-context
  schedule computed from
  sampling.additional_test_and_validation_requests_per_context. The pairing
  uses zip, so it stops at the shorter of the two. The pipeline populates every
  group independently, merges the results, and writes the merged dataset and
  counters to the paths of the given split.

  Args:
    benchmark_dir: Directory from which the output paths are derived.
    dataset: Dataset whose example groups receive additional examples.
    split: Split ID used to pick the output paths.
    original_dataset_spec: Dataset spec; a deep copy of it is used here.
    enable_remote_dependencies: Forwarded to
      _populate_a_single_example_group_with_more_examples.

  Returns:
    A function taking the pipeline root and attaching all transforms to it.
  """
  dataset_spec = copy.deepcopy(original_dataset_spec)
  random_seed = dataset_spec.generation_options.random_seed
  rng = np.random.RandomState(random_seed)

  # Schedule of examples per request type, one entry per context, targeting
  # additional_test_and_validation_requests_per_context requests per context.
  sampling_options = dataset_spec.generation_options.sampling
  target_num_requests_per_context = (
      sampling_options.additional_test_and_validation_requests_per_context)
  num_examples_by_type_schedule = (
      sampling_options.calculate_schedule_of_examples_by_type(
          target_num_requests_per_context))

  # One work item per context, each carrying its own spec (and, unless a
  # timestamp seed is requested, its own random seed) so that the contexts can
  # be processed in parallel.
  work_items: List[ExampleGroupWithSpecAndCounters] = []
  scheduled_groups = zip(dataset.example_groups, num_examples_by_type_schedule)
  for example_group, num_examples_by_type in scheduled_groups:
    # RandomState seed can be 'any integer between 0 and 2**32 - 1 inclusive'.
    # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html
    if not dataset_spec.generation_options.use_timestamp_for_random_seed():
      random_seed = rng.randint(0, 2**32)
      current_options = dataset_spec.generation_options
      one_context_sampling = dataclasses.replace(
          current_options.sampling, num_contexts=1)
      dataset_spec = dataclasses.replace(
          dataset_spec,
          generation_options=dataclasses.replace(
              current_options,
              random_seed=random_seed,
              sampling=one_context_sampling))
    # The counters are rebuilt from the example group itself.
    group_counters = dataset_io.counters_from_grouped_example_set(
        cl.GroupedExampleSet([example_group]))
    work_items.append(
        (example_group, dataset_spec, group_counters, num_examples_by_type))

  def pipeline(root):
    populate_fn = functools.partial(
        _populate_a_single_example_group_with_more_examples,
        enable_remote_dependencies=enable_remote_dependencies,
        allow_input_mutation=False)

    items = root | beam.Create(work_items)
    populated = items | 'PopulateGroupsWithMoreExamples' >> beam.Map(
        populate_fn)
    combined = (
        populated
        | 'CombineDatasetsAndCounters' >>
        beam.CombineGlobally(_combine_datasets_and_counters))
    tagged = (
        combined
        | 'SeparateDatasetFromCounters' >>
        beam.FlatMap(_extract_outputs).with_outputs('0', '1'))
    populated_dataset = tagged['0']
    populated_counters = tagged['1']

    dataset_json = populated_dataset | beam.Map(
        lambda x: json.dumps(copy.deepcopy(x).serialize(), indent=2))
    _ = dataset_json | 'WriteDataset' >> beam.io.WriteToText(
        dataset_io.get_split_path(benchmark_dir, split),
        shard_name_template='')

    counters_json = populated_counters | beam.Map(
        lambda x: x.to_json())
    _ = counters_json | 'WriteDatasetCounters' >> beam.io.WriteToText(
        dataset_io.get_split_counters_path(benchmark_dir, split),
        shard_name_template='')

  return pipeline


def _get_split_indices(second_split_fraction, rng,
                       labels):
  """Returns the indices splitting a sequence into two portions.

  Args:
    second_split_fraction: The fraction of items for the second_split.
    rng: Random number generator.
    labels: The labels will be roughly equally represented in the returned
      splits.
  """

  def get_first_split_size(length):
    # If there is only one item, we make sure it goes to first_split.
    if length > 1:
      second_split_size = int(length * second_split_fraction)
    else:
      second_split_size = 0
    first_split_size = length - second_split_size
    return first_split_size

  # We stratify the indices according to the labels and then split each stratum
  # in a way similar to scikit-learn's StratifiedShuffleSplit.
  # We choose not to use StratifiedShuffleSplit directly here since it does not
  # handle some edge cases well (e.g. when the sequence has only a few items and
  # the requested train_size and test_size would cause one of the splits to be
  # empty).
  indices_by_label = {}
  for i, label in enumerate(labels):
    indices_by_label.setdefault(label, []).append(i)

  first_split_indices = []
  second_split_indices = []
  for label, indices_of_label in indices_by_label.items():
    first_split_size = get_first_split_size(len(indices_of_label))
    rng.shuffle(indices_of_label)
    first_split_indices.extend(indices_of_label[:first_split_size])
    second_split_indices.extend((indices_of_label[first_split_size:]))

  # If there are more than one distinct label, there is no guarantee that at
  # this point the size of the splits would be even close to expected, so we try
  # to adjust it.
  rng.shuffle(first_split_indices)
  rng.shuffle(second_split_indices)
  all_indices = first_split_indices + second_split_indices
  first_split_size = get_first_split_size(len(all_indices))

  return all_indices[:first_split_size], all_indices[first_split_size:]


def _split_sequence(
    sequence,
    test_fraction,
    validation_fraction,
    rng,
    labels = None):
  """Distributes the items of a sequence over train, validation and test.

  The split happens in two stages: first into train versus everything held
  out, then the held-out portion into validation versus test.

  Args:
    sequence: The sequence of items to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.
    labels: If provided, must be of the same length as the sequence.  The labels
      will be roughly equally represented in the returned splits. If omitted,
      every item gets the same label.

  Returns:
    Mapping of split ID to the list of items assigned to that split.
  """
  if labels is None:
    labels = [0] * len(sequence)

  # Stage 1: train versus (validation + test).
  held_out_fraction = test_fraction + validation_fraction
  train_indices, held_out_indices = _get_split_indices(
      second_split_fraction=held_out_fraction, labels=labels, rng=rng)

  # Stage 2: validation versus test, operating on positions within
  # held_out_indices.
  held_out_labels = [labels[index] for index in held_out_indices]
  test_share = (
      0.0 if test_fraction == 0 else test_fraction / held_out_fraction)
  validation_positions, test_positions = _get_split_indices(
      second_split_fraction=test_share, labels=held_out_labels, rng=rng)

  def items_at(indices):
    return [sequence[index] for index in indices]

  return {
      tfds.Split.TRAIN:
          items_at(train_indices),
      tfds.Split.VALIDATION:
          items_at([held_out_indices[k] for k in validation_positions]),
      tfds.Split.TEST:
          items_at([held_out_indices[k] for k in test_positions]),
  }


def _split_dataset_by_example(dataset,
                              test_fraction, validation_fraction,
                              rng):
  """Splits the dataset at the granularity of individual examples.

  The dataset is flattened to an example set, the examples are split with
  their is_unreliable flag as the stratification label, and each split is then
  regrouped into a GroupedExampleSet.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.

  Returns:
    Mapping of split ID to the split contents.
  """
  flat_example_set = dataset.to_example_set()
  # Stratify on reliability so that reliable and unreliable examples are
  # roughly equally represented in the splits.
  reliability_labels = [
      example.is_unreliable for example in flat_example_set
  ]
  examples_by_split = _split_sequence(
      list(flat_example_set), test_fraction, validation_fraction, rng,
      reliability_labels)

  return {
      split: cl.GroupedExampleSet.from_example_set(
          cl.ExampleSet.from_examples(split_examples))
      for split, split_examples in examples_by_split.items()
  }


def _split_dataset_by_context(dataset,
                              test_fraction, validation_fraction,
                              rng):
  """Splits the dataset at the granularity of whole example groups.

  Example groups are stratified on two booleans: whether the context has
  omitted rules and whether it has unreliable rules. Representing these kinds
  of contexts equally in the splits avoids large biases in the number of
  unknown replies between the train and test splits.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.

  Returns:
    Mapping of split ID to the split contents.
  """
  stratification_labels = [
      (bool(group.context.metadata.omitted_rules),
       bool(group.context.metadata.unreliable_rules))
      for group in dataset.example_groups
  ]

  groups_by_split = _split_sequence(
      dataset.example_groups, test_fraction, validation_fraction, rng,
      stratification_labels)

  grouped_example_set_by_split = {}
  for split, groups_of_split in groups_by_split.items():
    grouped_example_set_by_split[split] = cl.GroupedExampleSet(
        example_groups=groups_of_split)
  return grouped_example_set_by_split


def _split_dataset_by_context_and_output_pattern(
    dataset, test_fraction,
    validation_fraction, rng):
  """Splits the dataset by context and output pattern.

  The current implementation splits all the output patterns in top-level
  examples equally into train/validation/test patterns, then filters the dataset
  splits by these allowed patterns in each split.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.

  Returns:
    Mapping of split ID to the split contents.
  """
  context_split = _split_dataset_by_context(dataset, test_fraction,
                                            validation_fraction, rng)

  # Collect the distinct output patterns in first-seen order. A plain set must
  # not be used here: its iteration order is not guaranteed (and for string
  # keys it changes between processes under hash randomization), which would
  # make the pattern-to-split assignment below irreproducible even with a
  # fixed rng seed. dict.fromkeys keeps insertion order and only requires the
  # patterns to be hashable, exactly as a set would.
  output_patterns = list(
      dict.fromkeys(
          example.get_output_pattern()
          for example in dataset.to_flat_examples()))

  # We split the output patterns roughly equally into three disjoint parts so
  # that the train/validation/test fractions are close to the requested
  # fractions. (Note that since the initial context and output pattern splits
  # are uniform and independent, we would not want to split both of them based
  # on the requested validation and test fractions, as that would lead to the
  # split sizes being proportional to something like requested_fraction^2).
  allowed_output_patterns_by_split = _split_sequence(
      output_patterns, 1.0 / 3, 1.0 / 3, rng)

  grouped_example_set_by_split = {}
  for split, grouped_example_set_for_split in context_split.items():
    allowed_output_patterns_for_split = (
        set(allowed_output_patterns_by_split[split]))
    # Keep only the examples whose output pattern is allowed in this split.
    filtered_examples = [
        example for example in grouped_example_set_for_split.to_flat_examples()
        if example.get_output_pattern() in allowed_output_patterns_for_split
    ]
    grouped_example_set_by_split[split] = cl.GroupedExampleSet.from_example_set(
        cl.ExampleSet.from_examples(filtered_examples))

  return grouped_example_set_by_split


def _split_dataset_by_context_and_output_pattern_and_reshuffle(
    dataset, test_fraction,
    validation_fraction, rng):
  """Splits by context and output pattern, then pools and resplits by context.

  The example groups that survive the output-pattern split are gathered back
  into one dataset and split again by context only. This is to remove the
  effect of splitting by output patterns to understand its effect on the
  learner's performance.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.

  Returns:
    Mapping of split ID to the split contents.
  """
  pattern_split = _split_dataset_by_context_and_output_pattern(
      dataset, test_fraction, validation_fraction, rng)

  # Pool the example groups of every split back into a single dataset.
  pooled_groups = list(
      itertools.chain.from_iterable(
          subset.example_groups for subset in pattern_split.values()))
  pooled_dataset = cl.GroupedExampleSet(example_groups=pooled_groups)

  # Resplit by context; no stratification labels are passed this time.
  groups_by_split = _split_sequence(
      pooled_dataset.example_groups, test_fraction, validation_fraction, rng)

  resplit = {}
  for split, groups_of_split in groups_by_split.items():
    resplit[split] = cl.GroupedExampleSet(example_groups=groups_of_split)
  return resplit


def _split_dataset_by_subsample_and_context(dataset,
                                            test_fraction,
                                            validation_fraction,
                                            rng,
                                            subsample_size):
  """Keeps a uniform subsample of the flat examples, then splits by context.

  All flat examples of the dataset are shuffled with rng, the first
  subsample_size of them are kept and regrouped into a new dataset, and that
  dataset is split by context.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.
    subsample_size: The total number of top-level examples to keep.

  Returns:
    Mapping of split ID to the split contents.
  """
  shuffled_examples = list(dataset.to_flat_examples())
  rng.shuffle(shuffled_examples)

  kept_example_set = cl.ExampleSet.from_examples(
      shuffled_examples[:subsample_size])
  subsampled_dataset = cl.GroupedExampleSet.from_example_set(kept_example_set)

  return _split_dataset_by_context(
      subsampled_dataset, test_fraction, validation_fraction, rng)


def _split_dataset_by_sorting(dataset,
                              test_fraction,
                              validation_fraction,
                              rng,
                              key,
                              reverse = False):
  """Splits the dataset by cutting a sorted list of example groups in three.

  The example groups are ordered by `key`. The leading slice becomes the train
  split, the following slice the validation split, and whatever remains the
  test split. The train and validation sizes are truncated to integers, so any
  rounding remainder ends up in the test split.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator. Unused; this split is deterministic.
    key: The function to extract a comparison key from each ExampleGroup.
    reverse: If True, sort the contexts in decreasing order.

  Returns:
    Mapping of split ID to the split contents.
  """
  del rng
  num_groups = len(dataset.example_groups)
  train_fraction = 1.0 - test_fraction - validation_fraction
  train_end = int(train_fraction * num_groups)
  validation_end = train_end + int(validation_fraction * num_groups)

  ordered_groups = sorted(dataset.example_groups, key=key, reverse=reverse)

  return {
      tfds.Split.TRAIN:
          cl.GroupedExampleSet(example_groups=ordered_groups[:train_end]),
      tfds.Split.VALIDATION:
          cl.GroupedExampleSet(
              example_groups=ordered_groups[train_end:validation_end]),
      tfds.Split.TEST:
          cl.GroupedExampleSet(example_groups=ordered_groups[validation_end:]),
  }


def _tuple_compounds_from_atoms(
    atoms,
    compound_size,
    compound_transformation = None
):
  """Returns the set of tuple compounds from atoms.

  A tuple compound is a sorted tuple of atoms.  For example, given the list
  ['a', 'b', 'c'] of atoms, the tuple 2-compounds are ('a', 'b'), ('a', 'c') and
  ('b', 'c').

  Args:
    atoms: A sorted list of strings, the list of atoms.
    compound_size: The size of each tuple compound.
    compound_transformation: If provided, is called on each atom before forming
      compounds.
  """
  if compound_transformation is None:
    return set(itertools.combinations(atoms, compound_size))
  else:
    transformed_rule_strings = sorted(set(map(compound_transformation, atoms)))
    return set(itertools.combinations(transformed_rule_strings, compound_size))


@functools.lru_cache(maxsize=None)
def _composition_compounds_from_example(
    example, compound_size):
  """Returns the set of composition compounds of an example.

  A production tree is rebuilt from the example's production provenance, and
  each of its subtrees of the requested size contributes one compound: a
  1-tuple holding that subtree's input string. Results are memoized per
  (example, compound_size) with an unbounded cache.
  """
  provenance = example.metadata.production_provenance
  tree = production_trees.ProductionTree.from_production_provenance(provenance)
  # Compounds are defined to be tuples of strings, hence the 1-tuples.
  return {
      (subtree.get_input_string(),)
      for subtree in tree.get_all_subtrees(size=compound_size)
  }


def _maximize_context_compound_divergence_with_swap(
    split, test_fraction, validation_fraction,
    rng, get_atoms_fn,
    get_compounds_fn,
    options):
  """Returns a new split of the example groups with maximized divergence.

  Starting from an existing split, example groups are swapped between splits
  with mcd_utils.swap_examples, first between train+validation and test, and
  then between train and validation.

  Args:
    split: Mapping from tfds.Split to the grouped example set of that split,
      used as the starting point.
    test_fraction: The requested fraction of items for the test split.
    validation_fraction: The requested fraction of items for the validation
      split.
    rng: Random number generator, forwarded to _split_sequence.
    get_atoms_fn: Function returning the atoms of an example group.
    get_compounds_fn: Function returning the compounds of an example group.
    options: Options providing max_iterations, max_divergence and
      min_atom_count for the swapping algorithm.
  """

  def swap_to_diverge(items_1, items_2):
    return mcd_utils.swap_examples(
        items_1,
        items_2,
        get_compounds_fn,
        get_atoms_fn,
        options.max_iterations,
        options.max_divergence,
        options.min_atom_count,
        print_frequencies=False)

  # Stage 1: pool train and validation so that the test split diverges from
  # both of them together.
  pooled_groups = (
      split[tfds.Split.TRAIN].example_groups +
      split[tfds.Split.VALIDATION].example_groups)
  pooled_groups, test_groups = swap_to_diverge(
      pooled_groups, split[tfds.Split.TEST].example_groups)

  # Stage 2: carve validation back out of the pool, then maximize divergence
  # between train and validation as well.  The validation fraction is rescaled
  # because the pool no longer contains the test items.
  relative_validation_fraction = validation_fraction / (1.0 - test_fraction)
  resplit = _split_sequence(
      pooled_groups,
      test_fraction=0,
      validation_fraction=relative_validation_fraction,
      rng=rng)
  train_groups, validation_groups = swap_to_diverge(
      resplit[tfds.Split.TRAIN], resplit[tfds.Split.VALIDATION])

  groups_by_split = {
      tfds.Split.TRAIN: train_groups,
      tfds.Split.VALIDATION: validation_groups,
      tfds.Split.TEST: test_groups,
  }
  return {
      split_key: cl.GroupedExampleSet(example_groups=groups)
      for split_key, groups in groups_by_split.items()
  }


def _get_sizes_by_stages(
    num_items, output_fraction, validation_fraction,
    test_fraction,
    two_stage_mcd):
  """Returns the requested output sizes of both stages of the splitting.

  Every stage splits its input items into two subsets.  The result has the
  shape:
    ((size of items1 in stage1, size of items2 in stage1),
     (size of items1 in stage2, size of items2 in stage2))

  How the sizes are to be read depends on two_stage_mcd:

    - If two_stage_mcd is False:
      ((size of train+validation, size of test),
       (size of train, size of validation))

    - If two_stage_mcd is True:
      ((size of train+validation+reserve, size of test),
       (size of train, size of validation))

      where "reserve" is half of the num_items * (1.0 - output_fraction) items
      that are discarded over the whole process.

  Args:
    num_items: The total number of items that will be split into train,
      validation, and test split.
    output_fraction: The fraction of num_items that the split sizes add up to.
    validation_fraction: The requested fraction for the validation split.  The
      validation size will be num_items * output_fraction * validation_fraction.
    test_fraction: The requested fraction for the test split.  The test size
      will be num_items * output_fraction * test_fraction.
    two_stage_mcd: Whether the compound divergence maximization algorithm is
      applied in both stages.  If True, half of the items to be discarded are
      kept through stage 1 so that stage 2 has room to discard them.
  """

  def scaled(fraction):
    # Sizes are truncated (not rounded) to whole numbers of items.
    return int(fraction * num_items)

  # The requested fractions are relative to the items that survive the
  # splitting; multiplying by output_fraction turns them into fractions of all
  # available items.
  train_fraction = 1.0 - validation_fraction - test_fraction
  absolute_train_fraction = output_fraction * train_fraction
  absolute_validation_fraction = output_fraction * validation_fraction
  absolute_test_fraction = output_fraction * test_fraction

  # Extra items kept in the first subset during stage 1, to be discarded during
  # stage 2.
  if two_stage_mcd:
    reserve = int(int(num_items * (1.0 - output_fraction)) / 2)
  else:
    reserve = 0

  # Stage 1: train+validation (plus the reserve) versus test.
  stage_1_sizes = (
      scaled(absolute_train_fraction + absolute_validation_fraction) + reserve,
      scaled(absolute_test_fraction),
  )
  # Stage 2: the final train and validation sizes.
  stage_2_sizes = (
      scaled(absolute_train_fraction),
      scaled(absolute_validation_fraction),
  )
  return stage_1_sizes, stage_2_sizes


# Generic type variable.  NOTE(review): presumably meant for the item type in
# the annotations of the splitting helpers; no use is visible in this section.
_T = TypeVar('_T')


def _split_train_and_validation_maintaining_atom_coverage(
    items, validation_size,
    get_atoms_fn,
    rng):
  """Returns train and validation splits with requested validation_size.

  This function performs a slightly constrained random split of the items into
  train and validation splits in a way so that the train split's atom coverage
  is as high as possible.

  This is to be used only for stage 2 (splitting validation from
  train+validation) of the splitting when SplitOptions.two_stage_mcd is False.

  If fewer than validation_size items remain after one item per atom has been
  reserved for train, reserved items are moved to validation to reach the
  requested size, which can reduce the train split's atom coverage.

  Args:
    items: The items to split into train and validation.
    validation_size: The size of the validation split.  A size of 0 yields an
      empty validation split.
    get_atoms_fn: The function calculating the set of atoms for every item.
    rng: A random number generator providing choice() and shuffle().

  Returns:
    A (train_items, validation_items) tuple of lists.
  """
  # Index the items by the atoms they contain.
  item_indices_by_atom = {}
  for index, item in enumerate(items):
    for atom in get_atoms_fn(item):
      item_indices_by_atom.setdefault(atom, []).append(index)

  # First we make sure every atom appears in train at least once.  Atoms are
  # visited in a fixed order (largest first) to ensure reproducibility.
  train_items = []
  assigned_indices = set()
  covered_atoms = set()
  for atom in reversed(sorted(item_indices_by_atom)):
    if atom in covered_atoms:
      # An item sampled for an earlier atom already brought this atom into
      # train, so it does not have to be targeted again.
      continue
    item_index = rng.choice(item_indices_by_atom[atom])
    assigned_indices.add(item_index)
    train_items.append(items[item_index])
    covered_atoms.update(get_atoms_fn(items[item_index]))

  remaining_items = [
      item for index, item in enumerate(items) if index not in assigned_indices
  ]

  if len(remaining_items) <= validation_size:
    validation_items = remaining_items

    # We make sure the requested validation_size is satisfied by moving items
    # from train to validation.
    train_indices = list(range(len(train_items)))
    rng.shuffle(train_indices)
    indices_to_move = train_indices[:validation_size - len(validation_items)]

    validation_items.extend(train_items[index] for index in indices_to_move)
    moved_indices = set(indices_to_move)
    train_items = [
        item for index, item in enumerate(train_items)
        if index not in moved_indices
    ]
  else:
    indices = list(range(len(remaining_items)))
    rng.shuffle(indices)
    # Use an explicit split point rather than negative slicing: with
    # validation_size == 0, indices[:-0] is empty and indices[-0:] is the whole
    # list, which would send every remaining item to validation.
    num_extra_train = len(indices) - validation_size
    train_items.extend(
        remaining_items[index] for index in indices[:num_extra_train])
    validation_items = [
        remaining_items[index] for index in indices[num_extra_train:]
    ]

  return train_items, validation_items


def _maximize_context_compound_divergence_with_insertion_deletion(
    dataset, test_fraction,
    validation_fraction, rng,
    get_atoms_fn,
    get_compounds_fn,
    options):
  """Returns a new split of the example groups with maximized divergence.

  The example groups of the dataset are split in two stages using
  divergence_maximization.maximize_divergence: first into train+validation and
  test, then into train and validation.  When options.two_stage_mcd is False,
  the second stage is instead a random split that maintains atom coverage in
  train.

  Args:
    dataset: The dataset whose example_groups are split.
    test_fraction: The requested fraction of items for the test split.
    validation_fraction: The requested fraction of items for the validation
      split.
    rng: Random number generator.
    get_atoms_fn: Function returning the atoms of an example group.
    get_compounds_fn: Function returning the compounds of an example group.
    options: Options providing output_fraction and two_stage_mcd; also
      forwarded to maximize_divergence.
  """
  all_groups = dataset.example_groups
  (pool_size, test_size), (train_size, validation_size) = _get_sizes_by_stages(
      len(all_groups), options.output_fraction, validation_fraction,
      test_fraction, options.two_stage_mcd)

  # Stage 1: train+validation versus test.
  pooled_groups, test_groups = divergence_maximization.maximize_divergence(
      all_groups, pool_size, test_size, get_compounds_fn, get_atoms_fn,
      options, rng)

  # Stage 2: train versus validation.
  if not options.two_stage_mcd:
    train_groups, validation_groups = (
        _split_train_and_validation_maintaining_atom_coverage(
            pooled_groups, validation_size, get_atoms_fn, rng))
  else:
    train_groups, validation_groups = (
        divergence_maximization.maximize_divergence(
            pooled_groups, train_size, validation_size, get_compounds_fn,
            get_atoms_fn, options, rng))

  groups_by_split = {
      tfds.Split.TRAIN: train_groups,
      tfds.Split.VALIDATION: validation_groups,
      tfds.Split.TEST: test_groups,
  }
  return {
      split_key: cl.GroupedExampleSet(example_groups=groups)
      for split_key, groups in groups_by_split.items()
  }


def _maximize_top_level_example_compound_divergence(
    dataset, test_fraction,
    validation_fraction, rng,
    get_atoms_fn,
    get_compounds_fn,
    options):
  """Returns a new split of the flat examples with maximized divergence.

  The example groups are flattened into individual examples, which are then
  split in two stages (train+validation versus test, then train versus
  validation).  If options.filter_contexts is set, examples are afterwards
  filtered so that each context is kept in only one split.

  Args:
    dataset: The dataset whose example groups are flattened and split.
    test_fraction: The requested fraction of items for the test split.
    validation_fraction: The requested fraction of items for the validation
      split.
    rng: Random number generator.
    get_atoms_fn: Function returning the atoms of an example.
    get_compounds_fn: Function returning the compounds of an example.
    options: Options providing output_fraction, two_stage_mcd and
      filter_contexts; also forwarded to maximize_divergence.
  """
  examples = [
      example
      for example_group in dataset.example_groups
      for example in example_group.to_flat_examples()
  ]

  (pool_size, test_size), (train_size, validation_size) = _get_sizes_by_stages(
      len(examples), options.output_fraction, validation_fraction,
      test_fraction, options.two_stage_mcd)

  # Stage 1: train+validation versus test.
  pooled_examples, test_examples = divergence_maximization.maximize_divergence(
      examples, pool_size, test_size, get_compounds_fn, get_atoms_fn, options,
      rng)

  # Stage 2: train versus validation.
  if not options.two_stage_mcd:
    train_examples, validation_examples = (
        _split_train_and_validation_maintaining_atom_coverage(
            pooled_examples, validation_size, get_atoms_fn, rng))
  else:
    train_examples, validation_examples = (
        divergence_maximization.maximize_divergence(
            pooled_examples, train_size, validation_size, get_compounds_fn,
            get_atoms_fn, options, rng))

  if options.filter_contexts:
    # At this point examples sharing the same context are spread over several
    # splits.  To avoid showing test contexts during training, each context is
    # assigned to a single split with a frequency-based heuristic:
    #   - The contexts occurring most often among the test examples go to test,
    #     capped at one third of the number of example groups in the dataset.
    #     (One third is used rather than test_fraction so that the test split
    #     does not end up with too few top-level examples; as a side effect the
    #     contexts are not distributed by the requested fractions.)
    #   - Of the contexts not assigned to test, those occurring most often
    #     among the validation examples go to validation, with the same cap.
    #   - All other contexts go to train.
    def most_frequent_contexts(context_iterable, max_contexts):
      counter = collections.Counter(context_iterable)
      return {context for context, _ in counter.most_common(max_contexts)}

    def keep_examples_with_context_in(candidates, allowed_contexts):
      return (
          example for example in candidates
          if example.context in allowed_contexts)

    num_test_contexts = round(len(dataset.example_groups) / 3)
    test_contexts = most_frequent_contexts(
        (example.context for example in test_examples), num_test_contexts)

    num_validation_contexts = round(len(dataset.example_groups) / 3)
    validation_contexts = most_frequent_contexts(
        (example.context
         for example in validation_examples
         if example.context not in test_contexts),
        num_validation_contexts)

    held_out_contexts = validation_contexts | test_contexts
    train_contexts = {
        example_group.context
        for example_group in dataset.example_groups
        if example_group.context not in held_out_contexts
    }

    train_examples = keep_examples_with_context_in(
        train_examples, train_contexts)
    validation_examples = keep_examples_with_context_in(
        validation_examples, validation_contexts)
    test_examples = keep_examples_with_context_in(test_examples, test_contexts)

  examples_by_split = {
      tfds.Split.TRAIN: train_examples,
      tfds.Split.VALIDATION: validation_examples,
      tfds.Split.TEST: test_examples,
  }
  return {
      split_key: cl.GroupedExampleSet.from_example_set(
          cl.ExampleSet.from_examples(split_examples))
      for split_key, split_examples in examples_by_split.items()
  }


def _compute_compound_divergence_splitting_stats(
    splits, get_atoms_fn,
    get_compounds_fn,
    options
):
  """Computes splitting_stats for compound divergence split.

  For every pair of splits the atom and compound divergences are measured, and
  for every split the number of items per atom and per compound is counted,
  together with the fraction of all observed atoms/compounds that the split
  covers.

  Args:
    splits: Mapping from tfds.Split to the grouped example set of that split.
    get_atoms_fn: Function returning the atoms of an item.
    get_compounds_fn: Function returning the compounds of an item.
    options: Options providing top_level_example, which selects whether items
      are flat examples (True) or example groups (False).

  Returns:
    The populated outputs.SplittingStatsOfAlgorithm.
  """
  atom_coef = 0.5
  compound_coef = 0.1

  split_keys = [tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST]

  def items_of_split(split_key):
    # The splits are compared at the same granularity at which the dataset was
    # split: flat examples or whole example groups.
    if options.top_level_example:
      return [
          example
          for example_group in splits[split_key].example_groups
          for example in example_group.to_flat_examples()
      ]
    return splits[split_key].example_groups

  items_by_split_key = {
      split_key: items_of_split(split_key) for split_key in split_keys
  }

  stats = outputs.SplittingStatsOfAlgorithm()

  for first_key, second_key in itertools.combinations(splits, 2):
    # The pair is turned into a string so that it can be serialized by
    # dataclasses_json's to_json method.
    pair_name = ':'.join((first_key, second_key))
    first_items = items_by_split_key[first_key]
    second_items = items_by_split_key[second_key]

    stats.atom_divergence[pair_name] = mcd_utils.measure_example_divergence(
        first_items, second_items, get_atoms_fn, atom_coef)
    stats.compound_divergence[pair_name] = (
        mcd_utils.measure_example_divergence(first_items, second_items,
                                             get_compounds_fn, compound_coef))

  all_atoms = set()
  all_compounds = set()
  for split_key in split_keys:
    num_items_by_atom = collections.Counter()
    num_items_by_compound = collections.Counter()
    for item in items_by_split_key[split_key]:
      # Atoms and compounds are tuples of strings up to this point, which
      # dataclasses_json fails to serialize, so they are stringified.
      num_items_by_atom.update(str(atom) for atom in get_atoms_fn(item))
      num_items_by_compound.update(
          str(compound) for compound in get_compounds_fn(item))

    stats.num_items_by_atom_by_split[split_key] = dict(num_items_by_atom)
    stats.num_items_by_compound_by_split[split_key] = dict(
        num_items_by_compound)

    all_atoms.update(num_items_by_atom.keys())
    all_compounds.update(num_items_by_compound.keys())

  # The coverage can only be computed once the atoms and compounds of all the
  # splits are known.
  for split_key in split_keys:
    num_atoms_in_split = len(stats.num_items_by_atom_by_split[split_key])
    stats.atom_coverage_by_split[split_key] = (
        num_atoms_in_split / len(all_atoms))
    num_compounds_in_split = len(
        stats.num_items_by_compound_by_split[split_key])
    stats.compound_coverage_by_split[split_key] = (
        num_compounds_in_split / len(all_compounds))

  return stats


def _get_compound_transformation(
    options
):
  """Returns the transformation to apply to rule strings, or None.

  Args:
    options: Options providing use_rule_pattern.  If it is set, the returned
      transformation is rule_conversion.rule_pattern_from_rule; otherwise no
      transformation is requested and None is returned.
  """
  return (
      rule_conversion.rule_pattern_from_rule
      if options.use_rule_pattern else None)


def _split_dataset_by_context_compound_divergence(
    dataset, test_fraction,
    validation_fraction, rng,
    options,
    splitting_stats):
  """Returns the dataset split by context compound divergence.

  Atoms and compounds are computed per example group.  A plain context split
  is always computed as well; its stats are recorded under
  inputs.SplitBy.CONTEXT as a baseline, next to the stats of the returned split
  under inputs.SplitBy.COMPOUND_DIVERGENCE.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.
    options: Options controlling the compound divergence splitting algorithm.
    splitting_stats: Statistics about dataset splitting; its
      stats_by_algorithm mapping is updated in place.
  """
  if options.composition_compound:

    def _compounds_of_size(example_group, compound_size):
      # Only the compounds that actually show up in a context example or in a
      # top-level example of the group are included.
      group_examples = itertools.chain(example_group.context, example_group)
      return set().union(*(
          _composition_compounds_from_example(
              example, compound_size=compound_size)
          for example in group_examples))

    def get_atoms_fn(example_group):
      return _compounds_of_size(example_group, 1)

    def get_compounds_fn(
        example_group):
      return set().union(*(
          _compounds_of_size(example_group, compound_size)
          for compound_size in range(2, options.max_compound_size + 1)))

  else:
    rule_transformation = _get_compound_transformation(options)

    # Example groups are not hashable, so the caching happens on this inner
    # function, which is keyed by the context.
    @functools.lru_cache(maxsize=None)
    def _sorted_rule_strings(context):
      # Sorting makes the compounds depend only on which atoms they contain,
      # not on the order in which the atoms are stored.
      return sorted(context.metadata.examples_by_rule.keys())

    def _rule_compounds(rule_strings, compound_size):
      return _tuple_compounds_from_atoms(
          rule_strings,
          compound_size=compound_size,
          compound_transformation=rule_transformation)

    def get_atoms_fn(example_group):
      return _rule_compounds(_sorted_rule_strings(example_group.context), 1)

    def get_compounds_fn(
        example_group):
      rule_strings = _sorted_rule_strings(example_group.context)
      return set().union(*(
          _rule_compounds(rule_strings, compound_size)
          for compound_size in range(2, options.max_compound_size + 1)))

  # The context split is always made: its stats are recorded as the baseline,
  # and it is the starting point of the swap-based algorithm.
  baseline_split = _split_dataset_by_context(dataset, test_fraction,
                                             validation_fraction, rng)

  if not options.use_insertion_deletion:
    divergence_split = _maximize_context_compound_divergence_with_swap(
        baseline_split, test_fraction, validation_fraction, rng, get_atoms_fn,
        get_compounds_fn, options)
  else:
    # The insertion/deletion-based algorithm starts from the dataset itself
    # and does not make use of the context split.
    divergence_split = (
        _maximize_context_compound_divergence_with_insertion_deletion(
            dataset, test_fraction, validation_fraction, rng, get_atoms_fn,
            get_compounds_fn, options))

  stats_by_algorithm = splitting_stats.stats_by_algorithm
  stats_by_algorithm[inputs.SplitBy.CONTEXT] = (
      _compute_compound_divergence_splitting_stats(
          baseline_split, get_atoms_fn, get_compounds_fn, options))
  stats_by_algorithm[inputs.SplitBy.COMPOUND_DIVERGENCE] = (
      _compute_compound_divergence_splitting_stats(
          divergence_split, get_atoms_fn, get_compounds_fn, options))

  return divergence_split


def _split_dataset_by_top_level_example_compound_divergence(
    dataset, test_fraction,
    validation_fraction, rng,
    options,
    splitting_stats):
  """Returns the dataset split by top-level example compound divergence.

  Atoms and compounds are computed per example.  A plain context split is also
  computed; its stats are recorded under inputs.SplitBy.CONTEXT as a baseline,
  next to the stats of the returned split under
  inputs.SplitBy.COMPOUND_DIVERGENCE.

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.
    options: Options controlling the compound divergence splitting algorithm.
    splitting_stats: Statistics about dataset splitting; its
      stats_by_algorithm mapping is updated in place.
  """
  # The example-level functions defined below are memoized per example and are
  # passed both to the splitting algorithm and to the stats computation.
  if options.composition_compound:

    @functools.lru_cache(maxsize=None)
    def get_atoms_fn(example):
      return _composition_compounds_from_example(example, compound_size=1)

    @functools.lru_cache(maxsize=None)
    def get_compounds_fn(example):
      return set().union(*(
          _composition_compounds_from_example(example, compound_size)
          for compound_size in range(2, options.max_compound_size + 1)))

  else:
    rule_transformation = _get_compound_transformation(options)

    def _rule_compounds(example, compound_size):
      # Sorting the rule strings makes the compounds depend only on which
      # atoms they contain, not on the order in which the atoms are stored.
      return _tuple_compounds_from_atoms(
          sorted(example.metadata.rules),
          compound_size=compound_size,
          compound_transformation=rule_transformation)

    @functools.lru_cache(maxsize=None)
    def get_atoms_fn(example):
      return _rule_compounds(example, 1)

    @functools.lru_cache(maxsize=None)
    def get_compounds_fn(example):
      rule_strings = sorted(example.metadata.rules)
      return set().union(*(
          _tuple_compounds_from_atoms(
              rule_strings,
              compound_size=compound_size,
              compound_transformation=rule_transformation)
          for compound_size in range(2, options.max_compound_size + 1)))

  divergence_split = _maximize_top_level_example_compound_divergence(
      dataset, test_fraction, validation_fraction, rng, get_atoms_fn,
      get_compounds_fn, options)

  # The context split is always made so that its stats can be recorded as the
  # baseline.
  baseline_split = _split_dataset_by_context(dataset, test_fraction,
                                             validation_fraction, rng)

  stats_by_algorithm = splitting_stats.stats_by_algorithm
  stats_by_algorithm[inputs.SplitBy.CONTEXT] = (
      _compute_compound_divergence_splitting_stats(
          baseline_split, get_atoms_fn, get_compounds_fn, options))
  stats_by_algorithm[inputs.SplitBy.COMPOUND_DIVERGENCE] = (
      _compute_compound_divergence_splitting_stats(
          divergence_split, get_atoms_fn, get_compounds_fn, options))

  return divergence_split


def _split_dataset_by_compound_divergence(
    dataset, test_fraction,
    validation_fraction, rng,
    options,
    splitting_stats):
  """Splits the dataset by context or top-level example compound divergence.

  Unless options.composition_compound is set, the atoms are rule strings, and
  the compounds are tuples of rule strings of sizes 2, 3, ...,
  max_compound_size.

  If use_rule_pattern is True, the rule strings are first transformed into
  "rule patterns" before atoms and compounds are formed from them.  For
  example, the rule pattern of "[x1 twice] = [x1] [x1]" is
  "[x1 _] = [x1] [x1]", which is also the rule pattern of other rules such as
  "[x1 left] = [x1] [x1]".

  Args:
    dataset: The dataset to split.
    test_fraction: The fraction of items for the test split.
    validation_fraction: The fraction of items for the validation split.
    rng: Random number generator.
    options: Options controlling the compound divergence splitting algorithm.
      Its top_level_example flag selects whether the split is made at the level
      of top-level examples or of contexts.
    splitting_stats: Statistics about dataset splitting.

  Returns:
    Mapping of split ID to the split contents.
  """
  split_fn = (
      _split_dataset_by_top_level_example_compound_divergence
      if options.top_level_example else
      _split_dataset_by_context_compound_divergence)
  return split_fn(dataset, test_fraction, validation_fraction, rng, options,
                  splitting_stats)


def split_dataset_and_compute_stats(
    dataset, options,
    rng):
  """Splits a dataset of examples into train, validation and test splits.

  The splitting algorithm is selected by options.split_by.  Each algorithm may
  optionally populate splitting stats; of the algorithms dispatched here, only
  the compound divergence one is handed the stats object.  After splitting,
  similarity_metadata.populate_train_similarity_metadata is called on the
  resulting splits.

  Args:
    dataset: Dataset of examples to split.
    options: Bundle of options controlling the splitting process.
    rng: Random number generator.

  Returns:
    Mapping of split ID to the split contents and splitting stats.

  Raises:
    ValueError: If options.split_by is not one of the supported algorithms.
  """
  logging.info('Splitting dataset')
  splitting_stats = outputs.SplittingStats()

  split_by = options.split_by
  # Positional arguments shared by every splitting algorithm.
  common_args = (dataset, options.test_fraction, options.validation_fraction,
                 rng)

  def explicit_fraction_of(example_group):
    return example_group.context.explicit_fraction

  def num_rules_of(example_group):
    return len(example_group.context.metadata.rules)

  if split_by == inputs.SplitBy.EXAMPLE:
    splits = _split_dataset_by_example(*common_args)
  elif split_by == inputs.SplitBy.CONTEXT:
    splits = _split_dataset_by_context(*common_args)
  elif split_by == inputs.SplitBy.LOW_EXPLICIT_FRACTION_IN_TRAIN:
    splits = _split_dataset_by_sorting(
        *common_args, explicit_fraction_of, reverse=False)
  elif split_by == inputs.SplitBy.HIGH_EXPLICIT_FRACTION_IN_TRAIN:
    splits = _split_dataset_by_sorting(
        *common_args, explicit_fraction_of, reverse=True)
  elif split_by == inputs.SplitBy.LOW_NUM_RULES_IN_TRAIN:
    splits = _split_dataset_by_sorting(
        *common_args, num_rules_of, reverse=False)
  elif split_by == inputs.SplitBy.HIGH_NUM_RULES_IN_TRAIN:
    splits = _split_dataset_by_sorting(
        *common_args, num_rules_of, reverse=True)
  elif split_by == inputs.SplitBy.CONTEXT_AND_OUTPUT_PATTERN:
    splits = _split_dataset_by_context_and_output_pattern(*common_args)
  elif split_by == inputs.SplitBy.CONTEXT_AND_OUTPUT_PATTERN_AND_RESHUFFLE:
    splits = _split_dataset_by_context_and_output_pattern_and_reshuffle(
        *common_args)
  elif split_by == inputs.SplitBy.SUBSAMPLE_AND_CONTEXT:
    splits = _split_dataset_by_subsample_and_context(
        *common_args, options.subsample_size)
  elif split_by == inputs.SplitBy.COMPOUND_DIVERGENCE:
    splits = _split_dataset_by_compound_divergence(
        *common_args, options.compound_divergence_options, splitting_stats)
  else:
    raise ValueError(f'SplitOption.split_by {options.split_by} not supported.')

  similarity_metadata.populate_train_similarity_metadata(splits)
  return splits, splitting_stats


def generate_summary_message(spec,
                             stats):
  """Returns a summary message appropriate for writing to logs or to stderr.

  Args:
    spec: The dataset spec; rendered with '%s'.
    stats: Generation stats providing counters (including the example_attempts
      counter) and timing (generate_dataset, split_dataset and total).

  Returns:
    A newline-joined string with the spec, the counters, the fraction of valid
    example attempts, and the elapsed times.
  """
  attempts = stats.counters.example_attempts
  timing = stats.timing
  lines = [
      'DATASET SPEC:\n%s' % spec,
      'COUNTERS:\n%s' % stats.counters,
      '  Valid fraction: %.3f (%d of %d)' %
      (attempts.get_valid_fraction(), attempts.valid, attempts.get_total()),
      'ElapsedTime (generate_dataset): %.3fs' % timing.generate_dataset,
      'ElapsedTime (split_dataset): %.3fs' % timing.split_dataset,
      'ElapsedTime (total): %.3fs' % timing.total,
  ]
  return '\n'.join(lines)


def _populate_example_groups_with_more_examples(
    example_set,
    dataset_spec,
    enable_remote_dependencies = False):
  """Populates the example set with more top level examples.

  Every example group of the example set is passed, together with its own
  counters and its entry of the per-type example schedule, to
  _populate_a_single_example_group_with_more_examples with input mutation
  allowed.

  Args:
    example_set: The grouped example set whose example groups are populated.
    dataset_spec: The dataset spec providing the template grammar and the
      sampling options.
    enable_remote_dependencies: Forwarded to the per-group population function.
  """
  # Load the fixed phrase structure grammar and collect its pass-through rules.
  grammar_template = (
      load_fixed_phrase_structure_grammar_template_for_dataset_spec(
          dataset_spec))
  productions = (
      nltk_utils.production_from_production_string(rule.to_rule_string())
      for rule in grammar_template.get_all_rules())
  # NOTE(review): pass_through_rules is not read again within this function.
  pass_through_rules = [
      production for production in productions
      if nltk_utils.is_pass_through_rule(production)
  ]

  # Build a new schedule of examples by request type, with the target number
  # of requests per context set to
  # additional_test_and_validation_requests_per_context.
  sampling_options = dataset_spec.generation_options.sampling
  schedule = sampling_options.calculate_schedule_of_examples_by_type(
      sampling_options.additional_test_and_validation_requests_per_context)

  for group, num_examples_by_type in zip(example_set.example_groups, schedule):
    group_counters = dataset_io.counters_from_grouped_example_set(
        cl.GroupedExampleSet([group]))
    _populate_a_single_example_group_with_more_examples(
        (group, dataset_spec, group_counters, num_examples_by_type),
        enable_remote_dependencies=enable_remote_dependencies,
        allow_input_mutation=True)


def load_fixed_phrase_structure_grammar_template_for_dataset_spec(
    spec):
  """Loads the fixed phrase structure template named by a dataset spec.

  The standard grammar identified by `spec.template_grammar_id` is loaded,
  converted into a grammar schema, validated against the grammar options of the
  spec, and finally turned into a fixed phrase structure template.

  Args:
    spec: The dataset spec of interest.

  Returns:
    The fixed phrase structure template derived from the standard grammar, or
    None if the dataset spec does not specify any template ID.

  Raises:
    ValueError: If the standard grammar specified in the dataset spec did not
      correspond to a valid GrammarSchema.
  """
  template_id = spec.template_grammar_id
  if template_id is not None:
    schema = grammar_representation.grammar_schema_from_feature_grammar(
        grammar_loader.load_standard_grammar(template_id))
    # Validation happens before the template is derived from the schema.
    schema.validate(spec.generation_options.grammar)
    template = (
        grammar_generation.fixed_phrase_structure_template_from_grammar_schema(
            schema))
    return template
  return None


def generate_benchmark_from_spec(
    benchmark_dir,
    should_overwrite,
    dataset_spec,
    beam_runner = None,
    enable_remote_dependencies = False,
):
  """Generates one Conceptual SCAN benchmark from an already-loaded spec.

  The run consists of two phases, dataset generation and dataset splitting.
  A phase is skipped when the stats in use already report a non-zero elapsed
  time for it (which can only happen when `should_overwrite` is False and
  stats were read back from `benchmark_dir`).

  Note: In most cases, callers should call `generate_benchmark` rather than
  calling this method directly.

  Args:
    benchmark_dir: Directory to which the benchmark will be written.
    should_overwrite: Whether to overwrite files that already exist. If True,
      the run starts from fresh stats; otherwise the stats are read from
      `benchmark_dir`.
    dataset_spec: DatasetSpec to use.
    beam_runner: Apache Beam runner for use in steps that are parallelized. If
      not specified, then will perform the generation locally without
      parallelization.
    enable_remote_dependencies: Whether to enable dependencies on remote
      services such as the T5X tokenizer. If False, then features that depend on
      such services (such as the tracking of example input and output lengths in
      COMPACT format using the T5X tokenizer) will be disabled, but dataset
      generation will proceed normally otherwise. Should generally be set to
      False in unit tests, where external network dependencies are normally not
      allowed. Note that this flag does not affect access to benchmark_dir,
      which the caller is free to configure to a local directory location when
      running from unit tests.

  Returns:
    Bundle of statistics about the benchmark generation run.
  """
  logging.info('Starting generation of Conceptual SCAN with spec: %s',
               dataset_spec)
  logging.info('Results will be written to: %s (should_overwrite=%s)',
               benchmark_dir, should_overwrite)

  options = dataset_spec.generation_options
  rng = np.random.RandomState(options.random_seed_or_timestamp())
  stats = (
      outputs.GenerationStats()
      if should_overwrite else dataset_io.read_stats(benchmark_dir))

  # Phase 1: dataset generation.
  generation_start_time = time.time()
  if not stats.timing.generate_dataset:
    dataset_io.write_dataset_spec(benchmark_dir, dataset_spec)
    if not beam_runner:
      # Local generation: the counters in `stats` are handed to the generator
      # and written out right after the dataset itself.
      grammar_template = (
          load_fixed_phrase_structure_grammar_template_for_dataset_spec(
              dataset_spec))
      dataset = dataset_generation.generate_dataset(
          options,
          stats.counters,
          rng,
          grammar_template,
          enable_remote_dependencies=enable_remote_dependencies)
      dataset_io.write_dataset(benchmark_dir, dataset)
      dataset_io.write_counters(benchmark_dir, stats.counters)
    else:
      # Beam generation: the pipeline writes its own output, so the counters
      # are read back from the benchmark directory afterwards.
      generation_pipeline = generate_and_write_dataset_pipeline(
          benchmark_dir,
          dataset_spec,
          enable_remote_dependencies=enable_remote_dependencies)
      beam_runner.run(generation_pipeline).wait_until_finish()
      stats.counters = dataset_io.read_dataset_counters(benchmark_dir)
    stats.timing.generate_dataset += time.time() - generation_start_time
    dataset_io.write_timing(benchmark_dir, stats.timing)
  else:
    logging.info('Dataset already exists. Skipping dataset generation phase.')

  # Phase 2: dataset splitting.
  split_start_time = time.time()
  if not stats.timing.split_dataset:
    if not beam_runner:
      full_dataset = dataset_io.read_dataset(benchmark_dir)
      splits, computed_splitting_stats = split_dataset_and_compute_stats(
          full_dataset, options.splitting, rng)
      stats.splitting_stats = computed_splitting_stats
      dataset_io.write_splits(benchmark_dir, splits)
      dataset_io.write_splitting_stats(benchmark_dir, computed_splitting_stats)
    else:
      splitting_pipeline = split_dataset_pipeline(dataset_spec, benchmark_dir)
      beam_runner.run(splitting_pipeline).wait_until_finish()
      stats.splitting_stats = dataset_io.read_splitting_stats(benchmark_dir)
      splits = dataset_io.read_splits(benchmark_dir)
    # The split timing is recorded before any additional examples are added.
    stats.timing.split_dataset += time.time() - split_start_time

    # A positive additional_test_and_validation_requests_per_context asks for
    # more examples in both the test and the validation split.
    num_additional_requests = (
        options.sampling.additional_test_and_validation_requests_per_context)
    evaluation_split_names = ('test', 'validation')
    if num_additional_requests > 0 and beam_runner:
      # Kick off one beam pipeline per split, then wait for them in the same
      # order in which they were started.
      pipeline_results = [
          beam_runner.run(
              populate_dataset_with_more_examples_per_context_pipeline(
                  benchmark_dir,
                  splits[split_name],
                  split_name,
                  dataset_spec,
                  enable_remote_dependencies=enable_remote_dependencies))
          for split_name in evaluation_split_names
      ]
      for pipeline_result in pipeline_results:
        pipeline_result.wait_until_finish()
      # Write the train split separately.
      dataset_io.write_splits(benchmark_dir, {'train': splits['train']})
    else:
      if num_additional_requests > 0:
        for split_name in evaluation_split_names:
          _populate_example_groups_with_more_examples(
              splits[split_name],
              dataset_spec=dataset_spec,
              enable_remote_dependencies=enable_remote_dependencies)
      dataset_io.write_splits(benchmark_dir, splits)
    dataset_io.write_timing(benchmark_dir, stats.timing)
  else:
    logging.info('Splits already exists. Skipping dataset split phase.')

  logging.info(generate_summary_message(dataset_spec, stats))
  return stats


def generate_benchmark(
    benchmark_dir,
    should_overwrite,
    dataset_spec_id,
    replica_index = None,
    beam_runner = None,
    deterministic = True,
    enable_remote_dependencies = False,
):
  """Generates a single Conceptual SCAN benchmark based on the given spec ID.

  Loads the dataset spec, optionally adjusts its random seed settings (see
  `deterministic` and `replica_index`), and delegates the actual work to
  `generate_benchmark_from_spec`.

  Args:
    benchmark_dir: Directory to which the benchmark will be written.
    should_overwrite: Whether to overwrite files that already exist.
    dataset_spec_id: Id of the DatasetSpec to use. Will automatically search for
      a spec with the given id in `data/dataset_specs.json`.
    replica_index: Dataset replica index (starting from 1), in the case where
      multiple different random datasets are to be generated for the same
      dataset spec. If specified, then will adjust the random seed specified in
      the dataset spec so that a different (but deterministically selected)
      random seed will be used for each replica. The seed is left unchanged
      when the spec uses the timestamp as its random seed.
    beam_runner: Apache Beam runner for use in steps that are parallelized. If
      not specified, then will perform the generation locally without
      parallelization.
    deterministic: If True, then will use the fixed random seed defined in the
      DatasetSpec, so as to guarantee that repeated generation runs with the
      same spec will yield the same results. If False, then will use a
      dynamically-chosen random seed based on the timestamp, which sacrifices
      reproducibility in order to avoid Apache Beam stragglers that can occur
      when a given random seed leads to a problematic grammar/context.
    enable_remote_dependencies: Whether to enable dependencies on remote
      services such as the T5X tokenizer. (See explanation in
      generate_benchmark_from_spec)

  Returns:
    Bundle of statistics about the benchmark generation run.
  """
  dataset_spec = dataset_spec_loader.load_dataset_spec(dataset_spec_id)
  if not deterministic:
    # NOTE(review): presumably random_seed=0 together with
    # random_seed_same_as=None is what makes the generation options fall back
    # to a timestamp-based seed -- confirm against the options class.
    dataset_spec = dataclasses.replace(
        dataset_spec,
        generation_options=dataclasses.replace(
            dataset_spec.generation_options,
            random_seed=0,
            random_seed_same_as=None))
  if replica_index:
    generation_options = dataset_spec.generation_options
    random_seed = generation_options.random_seed
    # The spec is not modified while the seed series is drawn, so whether the
    # timestamp is used as seed is decided once rather than on every draw.
    if not generation_options.use_timestamp_for_random_seed():
      rng = np.random.RandomState(random_seed)
      # Generate a series of random numbers such that the Nth number in the
      # series will become the new random seed for the Nth replica.
      for _ in range(replica_index):
        # RandomState seed can be 'any integer between 0 and 2**32 - 1
        # inclusive'.
        # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html
        random_seed = rng.randint(0, 2**32)

    dataset_spec = dataclasses.replace(
        dataset_spec,
        generation_options=dataclasses.replace(
            generation_options, random_seed=random_seed))

  return generate_benchmark_from_spec(
      benchmark_dir=benchmark_dir,
      should_overwrite=should_overwrite,
      dataset_spec=dataset_spec,
      beam_runner=beam_runner,
      enable_remote_dependencies=enable_remote_dependencies)
