# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, Sequence

from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import similarity_metrics

# Shared example fixtures for the tests below. The exact contents don't matter,
# as long as every example is distinct from all the others: the example-based
# similarity tests place these examples into contexts and rely on them being
# different from one another.
_EXAMPLE1 = cl.Example(request='q1', reply='r1')
_EXAMPLE2 = cl.Example(request='q2', reply='r2')
_EXAMPLE3 = cl.Example(request='q3', reply='r3')
_EXAMPLE4 = cl.Example(request='q4', reply='r4')
_EXAMPLE5 = cl.Example(request='q5', reply='r5')
_EXAMPLE6 = cl.Example(request='q6', reply='r6')
_EXAMPLE7 = cl.Example(request='q7', reply='r7')
_EXAMPLE8 = cl.Example(request='q8', reply='r8')


def _create_example_with_rules(request='r', reply='q', rules=()):
  """Builds a `cl.Example` whose metadata carries the given rules.

  Args:
    request: Request text of the example. The tests in this file never
      override the default.
    reply: Reply text of the example. The tests in this file never override
      the default.
    rules: Iterable of rules. It is converted to a set before being stored in
      the example's metadata.

  Returns:
    The newly constructed `cl.Example`.
  """
  example_metadata = cl.ExampleMetadata(rules=set(rules))
  return cl.Example(request=request, reply=reply, metadata=example_metadata)


def _create_context(rules=(), examples=()):
  """Builds a frozen context from the given examples and omitted rules.

  Args:
    rules: Iterable of rules, each registered on the context through
      `add_omitted_rule`.
    examples: Iterable of examples, each added to the context through
      `add_example`.

  Returns:
    A `cl.FrozenExampleSet` created from the populated `cl.ExampleSet`.
  """
  context_builder = cl.ExampleSet()
  # Examples are added before the omitted rules are registered.
  for context_example in examples:
    context_builder.add_example(context_example)
  for omitted_rule in rules:
    context_builder.add_omitted_rule(omitted_rule)
  frozen_context = cl.FrozenExampleSet.from_example_set(context_builder)
  return frozen_context


class SimilarityMetricsTest(parameterized.TestCase):
  """Tests for the context/example similarity functions in similarity_metrics."""

  def _assert_similarity_matrix_matches_results_of_similarity_function(
      self, similarity_matrix,
      validation_or_test_set,
      train_set,
      similarity_function
  ):
    """Asserts each matrix entry agrees with the pairwise similarity function.

    Entry `similarity_matrix[i][j]` is compared (to 3 decimal places) with
    `similarity_function` applied to the context of the i-th example group of
    `validation_or_test_set` and the context of the j-th example group of
    `train_set`.

    Args:
      similarity_matrix: Nested sequence of similarities; rows are indexed by
        test example group and columns by train example group.
      validation_or_test_set: Object exposing `example_groups`, each of which
        has a `context`; supplies the row contexts.
      train_set: Object exposing `example_groups`, each of which has a
        `context`; supplies the column contexts.
      similarity_function: Callable taking (test_context, train_context) and
        returning the similarity to compare against the matrix entry.
    """
    for test_index, similarity_matrix_row in enumerate(similarity_matrix):
      test_context = validation_or_test_set.example_groups[test_index].context
      for train_index, similarity_from_matrix in enumerate(
          similarity_matrix_row):
        train_context = train_set.example_groups[train_index].context
        similarity_from_function = similarity_function(test_context,
                                                       train_context)
        # The message reports both values so that a failure shows the actual
        # mismatch between the matrix entry and the function result.
        self.assertAlmostEqual(
            similarity_from_function,
            similarity_from_matrix,
            places=3,
            msg=(f'similarity_matrix[{test_index}][{train_index}] = '
                 f'{similarity_from_matrix} while similarity_function applied '
                 f'to the {test_index}th test context and {train_index}th '
                 f'train context = {similarity_from_function}'))

  # pyformat: disable
  @parameterized.named_parameters(
      ('context1_contains_no_rules',
       cl.FrozenExampleSet.from_examples([_EXAMPLE1]),
       cl.FrozenExampleSet.from_examples([_EXAMPLE2]), 1.0),
      ('all_rules_from_context1_are_in_context2',
       _create_context(rules=['a']),
       _create_context(rules=['a', 'b']), 1.0),
      ('some_rules_from_context1_are_in_context2',
       _create_context(rules=['a', 'b']),
       _create_context(rules=['a']), 0.5),
      ('no_rules_from_context1_are_in_context2',
       _create_context(rules=['a']),
       _create_context(rules=['b']), 0.0),
  )
  # pyformat: enable
  def test_calculate_context_similarity_by_rules(self, context1, context2,
                                                 expected_similarity):
    """Checks rule-based context similarity against the expected value."""
    self.assertEqual(
        expected_similarity,
        similarity_metrics.calculate_context_similarity_by_rules(
            context1, context2))

  @parameterized.named_parameters(
      ('all_examples_from_context1_are_in_context2',
       cl.FrozenExampleSet.from_examples([_EXAMPLE1]),
       cl.FrozenExampleSet.from_examples([_EXAMPLE1, _EXAMPLE2]), 1.0),)
  def test_calculate_context_similarity_by_examples(self, context1, context2,
                                                    expected_similarity):
    """Checks example-based context similarity against the expected value."""
    self.assertEqual(
        expected_similarity,
        similarity_metrics.calculate_context_similarity_by_examples(
            context1, context2))

  # pyformat: disable
  @parameterized.named_parameters(
      ('example_contains_no_rules',
       _create_example_with_rules(rules=[]),
       _create_context(rules=['a', 'b']), 1.0),
      ('all_rules_from_example_are_in_context',
       _create_example_with_rules(rules=['a']),
       _create_context(rules=['a', 'b']), 1.0),
      ('some_rules_from_example_are_in_context',
       _create_example_with_rules(rules=['a', 'c']),
       _create_context(rules=['a', 'b', 'd']), 0.5),
      ('no_rules_from_example_are_in_context',
       _create_example_with_rules(rules=['a', 'c']),
       _create_context(rules=['b']), 0.0),
  )
  # pyformat: enable
  def test_calculate_example_to_context_similarity_by_rules(
      self, example, context, expected_similarity):
    """Checks rule-based example-to-context similarity against expectations."""
    self.assertEqual(
        expected_similarity,
        similarity_metrics.calculate_example_to_context_similarity_by_rules(
            example, context))

  def test_context_similarity_matrix_by_rules(self):
    """Checks the rule-based similarity matrix for a 2x3 test/train setup.

    Verifies both the literal expected matrix and that every entry agrees with
    the pairwise `calculate_context_similarity_by_rules` function.
    """
    test_set = cl.GroupedExampleSet(example_groups=[
        cl.ExampleGroup(context=_create_context(rules=['a', 'b', 'c', 'd'])),
        cl.ExampleGroup(context=_create_context(rules=['e', 'f'])),
    ])
    train_set = cl.GroupedExampleSet(example_groups=[
        cl.ExampleGroup(context=_create_context(rules=['c', 'd', 'e'])),
        cl.ExampleGroup(context=_create_context(rules=['h'])),
        cl.ExampleGroup(context=_create_context(rules=['a'])),
    ])

    similarity_matrix = (
        similarity_metrics.calculate_context_similarity_matrix_by_rules(
            test_set, train_set))

    # One row per test group, one column per train group.
    with self.subTest('matrix_contains_expected_values'):
      self.assertEqual([[0.5, 0.0, 0.25], [0.5, 0.0, 0.0]],
                       similarity_matrix.tolist())

    with self.subTest('matrix_matches_results_of_similarity_function'):
      self._assert_similarity_matrix_matches_results_of_similarity_function(
          similarity_matrix.tolist(), test_set, train_set,
          similarity_metrics.calculate_context_similarity_by_rules)

  def test_context_similarity_matrix_by_examples(self):
    """Checks the example-based similarity matrix for a 2x3 test/train setup.

    Verifies both the literal expected matrix and that every entry agrees with
    the pairwise `calculate_context_similarity_by_examples` function.
    """
    test_set = cl.GroupedExampleSet(example_groups=[
        cl.ExampleGroup(
            context=_create_context(
                examples=[_EXAMPLE1, _EXAMPLE2, _EXAMPLE3, _EXAMPLE4])),
        cl.ExampleGroup(
            context=_create_context(examples=[_EXAMPLE5, _EXAMPLE6])),
    ])
    train_set = cl.GroupedExampleSet(example_groups=[
        cl.ExampleGroup(
            context=_create_context(
                examples=[_EXAMPLE1, _EXAMPLE2, _EXAMPLE3])),
        cl.ExampleGroup(
            context=_create_context(
                examples=[_EXAMPLE4, _EXAMPLE6, _EXAMPLE7])),
        cl.ExampleGroup(context=_create_context(examples=[_EXAMPLE8])),
    ])

    similarity_matrix = (
        similarity_metrics.calculate_context_similarity_matrix_by_examples(
            test_set, train_set))

    # One row per test group, one column per train group.
    with self.subTest('matrix_contains_expected_values'):
      self.assertEqual([[0.75, 0.25, 0.0], [0.0, 0.5, 0.0]],
                       similarity_matrix.tolist())

    with self.subTest('matrix_matches_results_of_similarity_function'):
      self._assert_similarity_matrix_matches_results_of_similarity_function(
          similarity_matrix.tolist(), test_set, train_set,
          similarity_metrics.calculate_context_similarity_by_examples)


# Run the test cases through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
