# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conceptual SCAN dataset spec definitions."""

from typing import Mapping, Sequence

import immutabledict

from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import inputs

def _uniform_example_lengths(
    max_input_length: int,
    max_output_length: int) -> inputs.ExampleLengthOptions:
  """Returns length options whose standard and compact limits coincide.

  Args:
    max_input_length: Value used for both max_input_length_standard and
      max_input_length_compact.
    max_output_length: Value used for both max_output_length_standard and
      max_output_length_compact.
  """
  return inputs.ExampleLengthOptions(
      max_input_length_standard=max_input_length,
      max_input_length_compact=max_input_length,
      max_output_length_standard=max_output_length,
      max_output_length_compact=max_output_length)


# Shared length limits, named _MAX_LENGTHS_<input limit>_<output limit>.
_MAX_LENGTHS_4096_512 = _uniform_example_lengths(4096, 512)
_MAX_LENGTHS_4096_256 = _uniform_example_lengths(4096, 256)
_MAX_LENGTHS_2048_256 = _uniform_example_lengths(2048, 256)

# Collection of all valid dataset specs.
#
# Conventions used by the entries below:
# - Every spec has its own random seed, except base_natural_language, which
#   reuses the seed of base and declares so via random_seed_same_as. When
#   adding a spec, take the seed from the counter below and then bump it.
# - Split fractions are written as `100 / num_contexts`. Going by the spec
#   descriptions (e.g. 300 contexts <-> "100 train contexts"), this reserves
#   100 contexts each for validation and test and leaves the rest for train.
#
# NEXT RANDOM SEED: 73
_DATASET_SPECS: Sequence[inputs.DatasetSpec] = (
    inputs.DatasetSpec(
        id='base',
        description=(
            'Base cSCAN dataset based on a grammar space close to the original '
            'SCAN grammar. For every context, the phrase structure grammar is '
            'equivalent to the one used in the original SCAN task, while the '
            'set of interpretation rules underlying each context is of roughly '
            'similar complexity to the original SCAN rule set.'),
        generation_options=inputs.GenerationOptions(
            random_seed=1,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=1200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 1200, test_fraction=100 / 1200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='base_mcd',
        description=(
            'Like base, but performs a top-level example maximum compound '
            'divergence split, where the atoms are the lefthand sides (i.e., '
            'input portions) of the interpretation rules that would have been '
            'composed to form the top-level request, and the compounds are the '
            'compositions of these atoms that occur in that request.'),
        generation_options=inputs.GenerationOptions(
            random_seed=69,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=12000,
                num_requests_per_context=100,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                split_by=inputs.SplitBy.COMPOUND_DIVERGENCE,
                compound_divergence_options=inputs.CompoundDivergenceOptions(
                    use_rule_pattern=True,
                    use_insertion_deletion=True,
                    top_level_example=True,
                    composition_compound=True,
                    filter_contexts=False,
                    output_fraction=0.1,
                    sample_size=20),
                validation_fraction=100 / 1200,
                test_fraction=100 / 1200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='extended',
        description=(
            'Extended cSCAN dataset. It is based on a richer grammar space and '
            'uses rule sampling to avoid too large contexts. Allows for longer '
            'output sequences for individual interpretation rules, while '
            'limiting the cumulative output sequence size.'),
        generation_options=inputs.GenerationOptions(
            random_seed=63,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=1200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 1200, test_fraction=100 / 1200)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='extended_mcd',
        description=(
            'Like extended, but performs a top-level example maximum compound '
            'divergence split, using the same splitting method as base_mcd.'),
        generation_options=inputs.GenerationOptions(
            random_seed=70,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=8000,
                num_requests_per_context=100,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                split_by=inputs.SplitBy.COMPOUND_DIVERGENCE,
                compound_divergence_options=inputs.CompoundDivergenceOptions(
                    use_rule_pattern=True,
                    use_insertion_deletion=True,
                    top_level_example=True,
                    composition_compound=True,
                    filter_contexts=False,
                    output_fraction=0.15,
                    sample_size=20),
                validation_fraction=100 / 1200,
                test_fraction=100 / 1200)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='base_100_contexts',
        description=('Like base, but with 100 train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=32,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=300,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 300, test_fraction=100 / 300)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='base_5000_contexts',
        description=('Like base, but with 5000 train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=59,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=5200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 5200, test_fraction=100 / 5200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='base_8K_contexts',
        # 8200 contexts with 100 / 8200 each for validation and test matches
        # the "8K" in the ID, following the same pattern as the sibling specs.
        description=('Like base, but with 8K train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=33,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=8200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 8200, test_fraction=100 / 8200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='base_100_contexts_1000_examples',
        description=(
            'Has the same number of top-level examples as base, but with 100 '
            'train contexts and 1000 requests per context.'),
        generation_options=inputs.GenerationOptions(
            random_seed=71,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=300,
                num_requests_per_context=1000,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 300, test_fraction=100 / 300)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='base_10K_contexts_10_examples',
        description=(
            'Has the same number of top-level examples as base, but with 10K '
            'train contexts and 10 requests per context.'),
        generation_options=inputs.GenerationOptions(
            random_seed=61,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=10200,
                num_requests_per_context=10,
                additional_test_and_validation_requests_per_context=990,
                lengths=_MAX_LENGTHS_2048_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 10200, test_fraction=100 / 10200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='extended_100_contexts',
        description=('Like extended, but with 100 train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=62,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=300,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 300, test_fraction=100 / 300)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='extended_5000_contexts',
        description=('Like extended, but with 5000 train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=64,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=5200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 5200, test_fraction=100 / 5200)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='extended_8K_contexts',
        # 8200 contexts with 100 / 8200 each for validation and test matches
        # the "8K" in the ID, following the same pattern as the sibling specs.
        description=('Like extended, but with 8K train contexts.'),
        generation_options=inputs.GenerationOptions(
            random_seed=55,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=8200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 8200, test_fraction=100 / 8200)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='extended_100_contexts_1000_examples',
        description=(
            'Has the same number of top-level examples as extended, but with '
            '100 train contexts and 1000 requests per context.'),
        generation_options=inputs.GenerationOptions(
            random_seed=72,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=300,
                num_requests_per_context=1000,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 300, test_fraction=100 / 300)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='extended_10K_contexts_10_examples',
        description=(
            'Has the same number of top-level examples as extended, but with '
            '10K train contexts and 10 requests per context.'),
        generation_options=inputs.GenerationOptions(
            random_seed=66,
            grammar=inputs.GrammarOptions(
                input_vocabulary=frozenset(
                    inputs.SCAN_INPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_INPUT_VOCABULARY),
                output_vocabulary=frozenset(
                    inputs.SCAN_OUTPUT_VOCABULARY
                    | inputs.SCAN_EXTENDED_OUTPUT_VOCABULARY),
                max_repetitions_per_token_in_output_sequence=6,
                max_output_sequence_size=8,
                max_cumulative_output_sequence_size=40),
            sampling=inputs.SamplingOptions(
                num_contexts=10200,
                num_requests_per_context=10,
                additional_test_and_validation_requests_per_context=990,
                explicit_fraction=0.8,
                explicit_fraction_stddev=0.2,
                min_explicit_fraction=0.6,
                lengths=_MAX_LENGTHS_4096_512),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 10200, test_fraction=100 / 10200)),
        template_grammar_id=(grammar_loader.StandardGrammarId.SCAN_EXTENDED)),
    inputs.DatasetSpec(
        id='base_natural_language',
        description=('Like base, but uses a natural language rule format.'),
        generation_options=inputs.GenerationOptions(
            # Same seed as 'base' on purpose; random_seed_same_as records the
            # spec whose seed is being shared.
            random_seed=1,
            random_seed_same_as='base',
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=1200,
                num_requests_per_context=100,
                additional_test_and_validation_requests_per_context=900,
                num_rules=-1,
                rule_format=enums.RuleFormat.NATURAL_LANGUAGE,
                lengths=_MAX_LENGTHS_4096_256),
            splitting=inputs.SplitOptions(
                validation_fraction=100 / 1200, test_fraction=100 / 1200)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='small_with_validation_and_test',
        description=(
            'Small dataset for use in testing, which is just big enough that '
            'the validation and test splits are non-empty.'),
        generation_options=inputs.GenerationOptions(
            random_seed=23,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=5,
                num_requests_per_context=2,
                additional_test_and_validation_requests_per_context=8),
            splitting=inputs.SplitOptions(
                validation_fraction=0.2, test_fraction=0.2)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='small_without_validation_or_test',
        description=(
            'Small dataset for use in testing, which contains so few contexts '
            'that the validation and test splits end up being empty.'),
        generation_options=inputs.GenerationOptions(
            random_seed=24,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=1,
                num_requests_per_context=10,
                omitted_fraction=0.0,
                omitted_fraction_stddev=0.0,
                unreliable_fraction=0.0,
                unreliable_fraction_stddev=0.0)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='small_with_unreliable_and_omitted',
        description=('Small dataset with unreliable and omitted rules for use '
                     'in testing.'),
        generation_options=inputs.GenerationOptions(
            random_seed=25,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=1,
                num_requests_per_context=10,
                omitted_fraction=0.1,
                omitted_fraction_stddev=0.0,
                # Setting unreliable_fraction to be 0.3 would result in too many
                # poor quality contexts.
                unreliable_fraction=0.1,
                unreliable_fraction_stddev=0.0)),
        template_grammar_id=(
            grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)),
    inputs.DatasetSpec(
        id='feature_grammar_small_without_template',
        description=('Like small_without_validation_or_test, but with '
                     'FEATURE_GRAMMAR_PRODUCTION as the rule format and '
                     'without a template_grammar_id. These are settings we '
                     'used early on for small scale test runs, prior to '
                     'adoption of the interpretation rule format.'),
        generation_options=inputs.GenerationOptions(
            random_seed=28,
            grammar=inputs.GrammarOptions(
                max_repetitions_per_token_in_output_sequence=2),
            sampling=inputs.SamplingOptions(
                num_contexts=1,
                num_requests_per_context=10,
                num_rules=-1,
                omitted_fraction=0.0,
                omitted_fraction_stddev=0.0,
                unreliable_fraction=0.0,
                unreliable_fraction_stddev=0.0,
                rule_format=enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION))),
)

# Read-only lookup table from dataset spec ID to the spec registered under it.
DATASET_SPEC_BY_ID: Mapping[str, inputs.DatasetSpec] = (
    immutabledict.immutabledict(
        (dataset_spec.id, dataset_spec) for dataset_spec in _DATASET_SPECS))
