# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import induction
from conceptual_learning.cscan import test_utils


class IllustrativeExamplesInductiveBiasTest(parameterized.TestCase):
  """Tests for `induction.IllustrativeExamplesInductiveBias`."""

  @parameterized.named_parameters(
      ('sufficient_reliable_examples', 3, True),
      ('insufficient_reliable_examples', 4, False),
  )
  def test_can_induce_rule(self, min_illustrative_examples, expected):
    """Only reliable illustrative examples count toward the minimum."""
    # The rule whose inducibility is under test.
    hidden_rule = 'hidden rule'
    reliable_companion = 'other reliable rule'
    unreliable = 'unreliable rule'
    distractor = 'distractor rule'

    example_set = cl.ExampleSet()
    example_set.mark_rule_as_unreliable(unreliable)

    # Four examples illustrate `hidden_rule`, but only the first three are
    # reliable: the last one carries a distractor for the unreliable rule.
    examples = [
        cl.Example(
            request='q1',
            reply='r1',
            metadata=cl.ExampleMetadata(rules={hidden_rule})),
        cl.Example(
            request='q2',
            reply='r2',
            metadata=cl.ExampleMetadata(rules={hidden_rule})),
        cl.Example(
            request='q3',
            reply='r3',
            metadata=cl.ExampleMetadata(
                rules={hidden_rule, reliable_companion})),
        cl.Example(
            request='q4',
            reply='r4',
            metadata=cl.ExampleMetadata(
                rules={hidden_rule, distractor},
                distractor_rules_by_unreliable_rule={
                    unreliable: [distractor],
                })),
    ]
    example_set.add_hidden_rule(hidden_rule, examples)
    frozen = cl.FrozenExampleSet.from_example_set(example_set)
    bias = induction.IllustrativeExamplesInductiveBias(
        min_illustrative_examples=min_illustrative_examples)

    self.assertEqual(bias.can_induce_rule(hidden_rule, frozen), expected)


class IllustrativeSubstitutionsInductiveBiasTest(parameterized.TestCase):
  """Tests for `induction.IllustrativeSubstitutionsInductiveBias`."""

  @parameterized.named_parameters(
      ('sufficient_substitutions_for_all_variables', 2, True),
      ('insufficient_substitutions_for_one_variable', 3, False),
      ('insufficient_substitutions_for_all_variables', 4, False),
  )
  def test_min_variable_substitutions(
      self, min_variable_substitutions, expected):
    """Checks the minimum number of substitutions required per variable."""
    rule = '[x1 and x2] = [x1] [x2]'
    # x1 is illustrated with 3 distinct substitutions and x2 with only 2.
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={
                rule: {
                    'x1': {'a', 'b', 'c'},
                    'x2': {'d', 'e'},
                },
            },
            outer_substitutions_by_rule={rule: {'__'}}))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=min_variable_substitutions,
        min_illustrative_outer_substitutions=1,
        reliable_examples_only=False)
    self.assertEqual(inductive_bias.can_induce_rule(rule, context), expected)

  @parameterized.named_parameters(
      ('bias_not_satisfied_by_unreliable_variable_substitutions', {'x1'},
       {'__'}, set(), {'__'}, False),
      ('bias_not_satisfied_by_unreliable_outer_substitutions', {'x1'}, {'__'},
       {'x1'}, set(), False),
      ('bias_satisfied_by_reliable_substitutions', {'x1'}, {'__'}, {'x1'},
       {'__'}, True),
  )
  def test_min_variable_substitutions_with_reliable_examples_only(
      self, variable_substitutions, outer_substitutions,
      reliable_variable_substitutions, reliable_outer_substitutions, expected):
    """With `reliable_examples_only`, only reliable substitutions count.

    Args:
      variable_substitutions: All substitutions illustrated for `x1`.
      outer_substitutions: All outer substitutions illustrated for the rule.
      reliable_variable_substitutions: The subset of substitutions for `x1`
        that come from reliable examples.
      reliable_outer_substitutions: The subset of outer substitutions that
        come from reliable examples.
      expected: Whether the rule is expected to be inducible.
    """
    rule = '[x1 and x2] = [x1] [x2]'
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={
                rule: {
                    'x1': variable_substitutions
                },
            },
            outer_substitutions_by_rule={rule: outer_substitutions},
            reliable_variable_substitutions_by_rule={
                rule: {
                    'x1': reliable_variable_substitutions
                },
            },
            reliable_outer_substitutions_by_rule={
                rule: reliable_outer_substitutions
            }))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=1,
        min_illustrative_outer_substitutions=1,
        reliable_examples_only=True)

    # Each scenario is already identified by its parameterized case name.
    self.assertEqual(inductive_bias.can_induce_rule(rule, context), expected)

  def test_rule_with_no_variables(self):
    """A rule with no variables trivially satisfies the variable requirement."""
    rule = '[jump] = JUMP'
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={rule: {}},
            outer_substitutions_by_rule={rule: {'__'}}))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=4,
        min_illustrative_outer_substitutions=1,
        reliable_examples_only=False)
    self.assertTrue(inductive_bias.can_induce_rule(rule, context))

  def test_variable_illustrated_in_unsubstituted_form(self):
    """Illustrating a variable in unsubstituted form satisfies the requirement.

    This holds even when that is the variable's only illustration.
    """
    rule = '[x1 twice] = [x1] [x1]'
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={
                rule: {
                    'x1': {'x2'}
                },
            },
            outer_substitutions_by_rule={rule: {'__'}}))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=4,
        min_illustrative_outer_substitutions=1,
        reliable_examples_only=False)
    self.assertTrue(inductive_bias.can_induce_rule(rule, context))

  @parameterized.named_parameters(
      ('sufficient_outer_substitutions', 2, True),
      ('insufficient_outer_substitutions', 3, False),
  )
  def test_min_outer_substitutions(self, min_outer_substitutions, expected):
    """Checks the minimum number of distinct outer substitutions required."""
    rule = '[jump] = [JUMP]'
    # The rule is illustrated in exactly two outer contexts.
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={rule: {}},
            outer_substitutions_by_rule={rule: {'__ x', 'y __'}}))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=1,
        min_illustrative_outer_substitutions=min_outer_substitutions,
        reliable_examples_only=False)
    self.assertEqual(inductive_bias.can_induce_rule(rule, context), expected)

  def test_rule_illustrated_at_top_of_rule_tree(self):
    """Illustrating the rule at the top of the rule tree ('__') suffices.

    This satisfies the outer requirement even when it is the rule's only
    illustration.
    """
    rule = '[jump] = [JUMP]'
    mutable_context = cl.ExampleSet(
        metadata=cl.ExampleSetMetadata(
            rule_format=enums.RuleFormat.INTERPRETATION_RULE,
            variable_substitutions_by_rule={rule: {}},
            outer_substitutions_by_rule={rule: {'__'}}))
    context = cl.FrozenExampleSet.from_example_set(mutable_context)
    inductive_bias = induction.IllustrativeSubstitutionsInductiveBias(
        min_illustrative_variable_substitutions=1,
        min_illustrative_outer_substitutions=2,
        reliable_examples_only=False)
    self.assertTrue(inductive_bias.can_induce_rule(rule, context))


@dataclasses.dataclass
class _ConstantInductiveBias(induction.InductiveBias):
  """Stub inductive bias whose verdict is fixed at construction time.

  Attributes:
    constant_response: The value returned by every `can_induce_rule` call,
      regardless of the rule or context passed in.
  """
  constant_response: bool

  def can_induce_rule(self, rule, context):
    """Ignores `rule` and `context` and returns `constant_response`."""
    del rule, context  # Unused: the verdict never depends on them.
    return self.constant_response


class GetRuleIllustrationQualityTest(parameterized.TestCase):
  """Tests for `induction.get_rule_illustration_quality`."""

  def test_explicit_rule_has_good_quality_regardless_of_inductive_bias(self):
    """An explicit rule is GOOD even under a bias that never induces."""
    explicit_rule = 'explicit rule'
    example_set = cl.ExampleSet()
    example_set.add_explicit_rule(
        explicit_rule,
        test_utils.create_example_from_explicit_rule(explicit_rule))
    frozen = cl.FrozenExampleSet.from_example_set(example_set)
    never_induces = _ConstantInductiveBias(constant_response=False)

    quality = induction.get_rule_illustration_quality(
        explicit_rule, frozen, never_induces)
    self.assertEqual(quality, cl.IllustrationQuality.GOOD)

  @parameterized.named_parameters(
      ('satisfies_inductive_bias', True, cl.IllustrationQuality.GOOD),
      ('does_not_satisfy_inductive_bias', False, cl.IllustrationQuality.POOR),
  )
  def test_hidden_rule_has_good_quality_if_it_satisfies_inductive_bias(
      self, can_induce, expected):
    """A hidden rule's quality follows whether the bias can induce it."""
    hidden_rule = 'hidden rule'
    example_set = cl.ExampleSet()
    example_set.add_hidden_rule(hidden_rule, [
        cl.Example(
            request='q1',
            reply='r1',
            metadata=cl.ExampleMetadata(rules={hidden_rule})),
    ])
    frozen = cl.FrozenExampleSet.from_example_set(example_set)
    bias = _ConstantInductiveBias(constant_response=can_induce)

    quality = induction.get_rule_illustration_quality(
        hidden_rule, frozen, bias)
    self.assertEqual(quality, expected)


if __name__ == '__main__':
  # Run every test case defined in this module via the absl test runner.
  absltest.main()
