# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic behavior tests for the dataset_generation module.

This file contains tests that verify the dataset_generation module's basic
behavior.  Tests for more specific edge cases (controlled by sampling options)
and tests for counters are in dataset_generation_additional_test.py.
"""

import traceback
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import nltk
import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import dataset_generation
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_loader
from conceptual_learning.cscan import induction
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import outputs
from conceptual_learning.cscan import sampling
from conceptual_learning.cscan import test_utils


class GenerateContextTest(absltest.TestCase):
  """Tests for dataset_generation._generate_context_and_inference_engine."""

  # Small grammar with four rules, shared by the tests in which every rule is
  # hidden.  Only four distinct examples can be derived from it.
  _FOUR_RULE_GRAMMAR = """% start S
    S[sem=(?x1+?x1)] -> U[sem=?x1] 'twice'
    S[sem=(?x1+?x1+?x1)] -> U[sem=?x1] 'thrice'
    U[sem='WALK'] -> 'walk'
    U[sem='JUMP'] -> 'jump'"""

  def setUp(self):
    """Gives each test its own random state seeded with 42."""
    super().setUp()
    self.rng = np.random.RandomState(42)

  @staticmethod
  def _all_hidden_options(require_inductive_bias, min_substitutions):
    """Returns sampling options under which every rule is hidden.

    Args:
      require_inductive_bias: Value passed as
        require_hidden_rules_to_satisfy_inductive_bias.
      min_substitutions: Value used for both the variable and the outer
        substitution minimum of the IllustrativeSubstitutionsInductiveBias.
    """
    bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=min_substitutions,
        min_illustrative_outer_substitutions=min_substitutions)
    return inputs.SamplingOptions(
        omitted_fraction=0.0,
        omitted_fraction_stddev=0.0,
        unreliable_fraction=0.0,
        unreliable_fraction_stddev=0.0,
        explicit_fraction=0.0,
        explicit_fraction_stddev=0.0,
        num_examples_per_hidden_rule=1,
        rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        max_attempts_per_example=10,
        max_attempts_per_context=1,
        require_hidden_rules_to_satisfy_inductive_bias=require_inductive_bias,
        inductive_bias=bias)

  def _make_generator(self, grammar, options):
    """Returns a new (ExampleGenerator, GenerationCounters) pair.

    The generator is built from `grammar`, this test's rng, `options` and the
    counters object that is returned alongside it.
    """
    counters = outputs.GenerationCounters()
    return (sampling.ExampleGenerator(grammar, self.rng, options, counters),
            counters)

  @mock.patch.object(
      dataset_generation.inference,
      'InferenceEngine',
      return_value=test_utils.make_fake_inference_engine())
  def test_generate_context_in_interpretation_rule_format(self, unused_mock):
    """Rules and example requests should use the interpretation rule format."""
    grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    options = test_utils.create_sampling_options(
        explicit_fraction=1.0, rule_format=enums.RuleFormat.INTERPRETATION_RULE)
    example_generator, counters = self._make_generator(grammar, options)
    context, unused_inference_engine = (
        dataset_generation._generate_context_and_inference_engine(
            example_generator, options, counters, self.rng))

    expected_rules = [
        '[x1 and x2] = [x1] [x2]',
        '[x1 after x2] = [x2] [x1]',
        '[x1 twice] = [x1] [x1]',
        '[x1 thrice] = [x1] [x1] [x1]',
        '[x1 opposite x2] = [x2] [x2] [x1]',
        '[x1 around x2] = [x2] [x1] [x2] [x1] [x2] [x1] [x2] [x1]',
        '[x1 x2] = [x2] [x1]',
        '[walk] = WALK',
        '[look] = LOOK',
        '[run] = RUN',
        '[jump] = JUMP',
        '[turn] = EMPTY_STRING',
        '[left] = LTURN',
        '[right] = RTURN',
    ]
    # Each expected rule should show up as an example asserting it to be true.
    expected_examples = [
        cl.Example(request=expected_rule, reply=cl.RuleReply.TRUE)
        for expected_rule in expected_rules
    ]

    with self.subTest('metadata_rules_should_be_in_interpretation_rule_format'):
      self.assertCountEqual(context.metadata.rules, expected_rules)

    with self.subTest(
        'example_request_should_be_in_interpretation_rule_format'):
      self.assertCountEqual(context._get_examples(), expected_examples)

  def test_should_raise_error_if_failing_context_quality_check(self):
    """Generation should give up when no context can pass the quality check."""
    grammar = grammar_loader.load_standard_grammar(
        grammar_loader.StandardGrammarId.SCAN_FINITE_NYE_STANDARDIZED)
    # The grammar has 14 rules.  With such a high unreliable_fraction exactly
    # one rule ends up hidden and all the others unreliable, so no generated
    # context can pass the quality check.
    options = inputs.SamplingOptions(
        omitted_fraction=0.0,
        omitted_fraction_stddev=0.0,
        unreliable_fraction=0.95,
        unreliable_fraction_stddev=0.0,
        explicit_fraction=0.0,
        explicit_fraction_stddev=0.0,
        num_examples_per_hidden_rule=2,
        rule_format=enums.RuleFormat.INTERPRETATION_RULE,
        max_attempts_per_example=10,
        max_attempts_per_context=1,
        require_hidden_rules_to_satisfy_inductive_bias=True,
        inductive_bias=induction.IllustrativeExamplesInductiveBias())
    example_generator, counters = self._make_generator(grammar, options)

    with self.assertRaisesRegex(
        sampling.MaxAttemptsReachedError,
        'Failed to generate context of good illustration quality.'):
      dataset_generation._generate_context_and_inference_engine(
          example_generator, options, counters, self.rng)

  def test_hidden_rules_should_be_unknown_if_allowed_to_fail_inductive_bias(
      self):
    """Hidden rules that miss the inductive bias should be marked unknown."""
    # All rules of the minimal grammar are hidden, and the substitution minimum
    # of 100 can never be met.  Because hidden rules are not required to
    # satisfy the inductive bias here, generation should still succeed.
    grammar = nltk.grammar.FeatureGrammar.fromstring(self._FOUR_RULE_GRAMMAR)
    options = self._all_hidden_options(
        require_inductive_bias=False, min_substitutions=100)
    example_generator, counters = self._make_generator(grammar, options)
    try:
      context, inference_engine = (
          dataset_generation._generate_context_and_inference_engine(
              example_generator, options, counters, self.rng))
    except Exception:
      self.fail(
          f'Generation should not have failed: {traceback.format_exc()}')

    metadata = context.metadata

    with self.subTest('context_should_be_successfully_generated'):
      # The grammar admits only 4 examples, and all of them are needed to give
      # each of the 4 rules one illustrative example.
      self.assertLen(metadata.rules, 4)
      self.assertLen(context, 4)

    with self.subTest('all_rules_should_be_hidden_unknown'):
      self.assertLen(metadata.rules, len(metadata.hidden_rules))
      self.assertLen(metadata.rules, len(metadata.hidden_unknown_rules))
      self.assertEmpty(metadata.hidden_true_rules)

    with self.subTest('illustrative_examples_should_still_be_monotonic_true'):
      # With no pass-through or explicit rules in the grammar, each example in
      # the context should contribute exactly one monotonic production.
      self.assertLen(inference_engine.monotonic_productions, len(context))
      every_example_is_monotonic = all(
          inference_engine.contains_monotonic_production(
              example.metadata.production) for example in context)
      self.assertTrue(
          every_example_is_monotonic,
          f'\nContext: {context}\nInference engine: {inference_engine}')

    with self.subTest('none_of_the_hidden_rules_should_be_inducible'):
      self.assertEmpty(inference_engine.defeasible_productions)

  def test_hidden_rules_should_be_true_if_satisfying_inductive_bias(self):
    """Hidden rules that meet the inductive bias should be marked true."""
    # All rules of the minimal grammar are hidden, and a substitution minimum
    # of 1 means they are always illustrated enough to satisfy the bias.
    grammar = nltk.grammar.FeatureGrammar.fromstring(self._FOUR_RULE_GRAMMAR)
    options = self._all_hidden_options(
        require_inductive_bias=True, min_substitutions=1)
    example_generator, counters = self._make_generator(grammar, options)
    try:
      context, inference_engine = (
          dataset_generation._generate_context_and_inference_engine(
              example_generator, options, counters, self.rng))
    except Exception:
      self.fail(
          f'Generation should not have failed: {traceback.format_exc()}')

    metadata = context.metadata

    with self.subTest('context_should_be_successfully_generated'):
      # The grammar admits only 4 examples, and all of them are needed to give
      # each of the 4 rules one illustrative example.
      self.assertLen(metadata.rules, 4)
      self.assertLen(context, 4)

    with self.subTest('all_rules_should_be_hidden_true'):
      self.assertLen(metadata.rules, len(metadata.hidden_rules))
      self.assertLen(metadata.rules, len(metadata.hidden_true_rules))
      self.assertEmpty(metadata.hidden_unknown_rules)

    with self.subTest(
        'defeasible_inference_engine_should_contain_the_hidden_rules'):
      self.assertNotEmpty(inference_engine.defeasible_productions)

  def test_should_raise_error_if_inconsistent_grammar(self):
    """An inconsistent grammar should surface as an InconsistencyError."""
    # The V nonterminals of the first production can be expanded with the two
    # other productions in two ways that produce different output tokens for
    # the same input.  One way:
    #   S[sem=(?x1+?x2)] -> D[sem=?x1] V[sem=?x2] 'fep', then
    #   S[sem=(?x1+?x2+?x3+?x2)] -> D[sem=?x1] D[sem=?x2] D[sem=?x3] 'fep'
    # The other way:
    #   S[sem=(?x1+?x2)] -> V[sem=?x1] D[sem=?x2] 'fep', then
    #   S[sem=(?x1+?x2+?x1+?x3)] -> D[sem=?x1] D[sem=?x2] D[sem=?x3] 'fep'
    grammar_string = """% start S
    S[sem=(?x1+?x2)] -> V[sem=?x1] V[sem=?x2] 'fep'
    V[sem=(?x1+?x2+?x1)] -> D[sem=?x1] D[sem=?x2]
    V[sem=?x1] -> D[sem=?x1]
    """
    grammar = nltk.grammar.FeatureGrammar.fromstring(grammar_string)
    options = inputs.SamplingOptions(
        omitted_fraction=0.0, unreliable_fraction=0.0)
    example_generator, counters = self._make_generator(grammar, options)

    with self.assertRaisesRegex(inference.InconsistencyError,
                                'Adding production would cause inconsistency'):
      dataset_generation._generate_context_and_inference_engine(
          example_generator, options, counters, self.rng)


class ContextQualityCheckTest(parameterized.TestCase):
  """Tests for dataset_generation._context_quality_check."""

  @staticmethod
  def _example(request, reply, rule, distractors=None):
    """Builds a cl.Example whose metadata records the single rule `rule`.

    Args:
      request: Request of the example.
      reply: Reply of the example.
      rule: The only member of the metadata's `rules` set.
      distractors: Optional value for the metadata's
        distractor_rules_by_unreliable_rule; left unset when None.
    """
    if distractors is None:
      metadata = cl.ExampleMetadata(rules={rule})
    else:
      metadata = cl.ExampleMetadata(
          rules={rule}, distractor_rules_by_unreliable_rule=distractors)
    return cl.Example(request=request, reply=reply, metadata=metadata)

  @parameterized.named_parameters(
      ('require_hidden_rules_to_satisfy_inductive_bias', True),
      ('do_not_require_hidden_rules_to_satisfy_inductive_bias', False))
  def test_should_detect_context_with_poor_illustration_quality(
      self, require_hidden_rules_to_satisfy_inductive_bias):
    """The check should fail only if the inductive bias is required."""
    hidden_rule = 'hidden rule'
    unreliable = 'unreliable rule'
    first_distractor = 'distractor rule 0'
    second_distractor = 'distractor rule 1'
    example_set = cl.ExampleSet()
    example_set.mark_rule_as_unreliable(unreliable)
    # Only 3 of the 5 examples (60%) illustrate the hidden rule reliably.  The
    # context should fail the quality check even though every unreliable rule
    # in it is illustrated in at least two different ways.
    illustrations = [
        self._example('q1', 'r1', hidden_rule),
        self._example('q2', 'r2', hidden_rule),
        self._example('q3', 'r3', hidden_rule),
        self._example('q4', 'r4', hidden_rule,
                      {unreliable: [first_distractor, second_distractor]}),
        self._example('q5', 'r5', hidden_rule,
                      {unreliable: [first_distractor, second_distractor]}),
    ]
    example_set.add_hidden_rule(hidden_rule, illustrations)
    context = cl.FrozenExampleSet.from_example_set(example_set)
    options = inputs.SamplingOptions(
        require_hidden_rules_to_satisfy_inductive_bias=(
            require_hidden_rules_to_satisfy_inductive_bias),
        inductive_bias=induction.IllustrativeExamplesInductiveBias())
    counters = outputs.GenerationCounters()

    passed_check = dataset_generation._context_quality_check(
        context, options, counters)
    counters_updated = counters.context_attempts.poor_illustration_quality > 0

    with self.subTest(
        'fails_quality_check_if_hidden_rules_required_to_satisfy_inductive_bias'
    ):
      self.assertEqual(not require_hidden_rules_to_satisfy_inductive_bias,
                       passed_check)

    with self.subTest('updates_counters_only_if_quality_check_failed'):
      self.assertNotEqual(passed_check, counters_updated)

  def test_should_detect_context_not_illustrating_unreliable_rules_sufficiently(
      self):
    """The check should fail if an unreliable rule has one distractor only."""
    hidden_rule = 'hidden rule'
    unreliable = 'unreliable rule'
    unreliable_insufficient = 'unreliable rule insufficient'
    first_distractor = 'distractor rule 0'
    second_distractor = 'distractor rule 1'
    example_set = cl.ExampleSet()
    example_set.mark_rule_as_unreliable(unreliable)
    example_set.mark_rule_as_unreliable(unreliable_insufficient)
    # The hidden rule is illustrated reliably sufficiently many times, but one
    # of the unreliable rules is illustrated in only one way across the
    # examples, so the context should fail the quality check.
    illustrations = [
        self._example('q1', 'r1', hidden_rule),
        self._example('q2', 'r2', hidden_rule),
        self._example('q3', 'r3', hidden_rule),
        self._example('q4', 'r4', hidden_rule),
        self._example(
            'q5', 'r5', hidden_rule, {
                unreliable: [first_distractor, second_distractor],
                unreliable_insufficient: [first_distractor],
            }),
        self._example(
            'q6', 'r6', hidden_rule, {
                unreliable: [first_distractor, second_distractor],
                # A single example may use the same unreliable rule more than
                # once.
                unreliable_insufficient: [first_distractor, first_distractor],
            }),
    ]
    example_set.add_hidden_rule(hidden_rule, illustrations)
    context = cl.FrozenExampleSet.from_example_set(example_set)
    options = inputs.SamplingOptions(
        inductive_bias=induction.IllustrativeExamplesInductiveBias())
    counters = outputs.GenerationCounters()

    passed_check = dataset_generation._context_quality_check(
        context, options, counters)
    self.assertFalse(passed_check)


class FilterExamplesIllustratingDistractorRulesTest(parameterized.TestCase):
  """Tests for _filter_examples_illustrating_distractor_rules."""

  @parameterized.named_parameters(('max_not_reached', 2, 2, 0),
                                  ('max_reached', 1, 1, 1))
  def test_should_allow_examples_illustrating_distractor_rule_if_below_max(
      self, max_distractor_rule_illustration, expected_num_yielded_examples,
      expected_failed_example_attempts_counter):
    """Checks filtering of examples that reuse a distractor rule.

    Args:
      max_distractor_rule_illustration: Limit passed to the filter under test.
      expected_num_yielded_examples: Number of the two candidate examples
        expected to come out of the filter.
      expected_failed_example_attempts_counter: Expected value of the
        illustrating_distractor_rule_too_many_times counter afterwards.
    """
    mutable_context = cl.ExampleSet()
    unreliable_rule = 'unreliable rule'
    distractor_rule = 'distractor rule'
    mutable_context.mark_rule_as_unreliable(unreliable_rule)
    example = cl.Example(
        request='q1',
        reply='r1',
        metadata=cl.ExampleMetadata(
            rules={distractor_rule},
            distractor_rules_by_unreliable_rule={
                unreliable_rule: [distractor_rule]
            }))
    mutable_context.add_unreliable_rule(unreliable_rule, [example])

    # Sanity-check the fixture before exercising the filter.
    with self.subTest('context_metadata_should_record_distractor_rule'):
      self.assertEqual(
          mutable_context.metadata.distractor_rules_by_unreliable_rule,
          {unreliable_rule: [distractor_rule]})

    with self.subTest(
        'context_metadata_should_record_example_using_distractor_rule'):
      self.assertEqual(mutable_context.metadata.examples_by_rule,
                       {distractor_rule: [example]})

    # At this point we have a mutable context using the distractor rule exactly
    # once.
    other_distractor_rule = 'other distractor rule'
    examples = [
        # This example illustrates the distractor rule already in the context,
        # so it is filtered out depending on whether
        # max_distractor_rule_illustration has been reached.
        cl.Example(
            request='q2',
            reply='r2',
            metadata=cl.ExampleMetadata(
                rules={distractor_rule},
                distractor_rules_by_unreliable_rule={
                    unreliable_rule: [distractor_rule]
                })),
        # This example illustrates the same unreliable rule with a different
        # distractor rule, so it is not filtered out.
        cl.Example(
            request='q3',
            reply='r3',
            metadata=cl.ExampleMetadata(
                rules={other_distractor_rule},
                distractor_rules_by_unreliable_rule={
                    unreliable_rule: [other_distractor_rule]
                }))
    ]
    counters = outputs.GenerationCounters()
    yielded_examples = list(
        dataset_generation._filter_examples_illustrating_distractor_rules(
            examples, max_distractor_rule_illustration, mutable_context,
            counters))

    with self.subTest('should_allow_examples_illustrating_distractor_rule'):
      self.assertLen(yielded_examples, expected_num_yielded_examples)

    # The counter is expected to be 0 in the 'max_not_reached' case and 1 in
    # the 'max_reached' case, so the label must not claim it never increments.
    with self.subTest('should_increment_example_attempts_only_when_filtering'):
      self.assertEqual(
          counters.example_attempts.illustrating_distractor_rule_too_many_times,
          expected_failed_example_attempts_counter)


if __name__ == '__main__':
  # Hand control to the absltest runner when this file is executed directly.
  absltest.main()
