# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Iterable

from absl.testing import absltest
from absl.testing import parameterized
import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import similarity_metadata

# Placeholder examples shared by the tests below. Their exact contents are
# irrelevant; the tests only rely on no two of them being equal.
(_EXAMPLE1, _EXAMPLE2, _EXAMPLE3, _EXAMPLE4,
 _EXAMPLE5, _EXAMPLE6, _EXAMPLE7, _EXAMPLE8) = (
     cl.Example(request='q1', reply='r1'),
     cl.Example(request='q2', reply='r2'),
     cl.Example(request='q3', reply='r3'),
     cl.Example(request='q4', reply='r4'),
     cl.Example(request='q5', reply='r5'),
     cl.Example(request='q6', reply='r6'),
     cl.Example(request='q7', reply='r7'),
     cl.Example(request='q8', reply='r8'),
 )


def _create_context(
    rules: Iterable[str] = (),
    examples: Iterable['cl.Example'] = (),
) -> 'cl.FrozenExampleSet':
  """Builds a frozen context set for use in test example groups.

  Args:
    rules: Rules registered on the context via `add_omitted_rule`.
    examples: Examples added to the context via `add_example`, in order.

  Returns:
    A `cl.FrozenExampleSet` created from the populated `cl.ExampleSet`.
  """
  builder = cl.ExampleSet()
  # Examples are added before the omitted rules, matching the original order.
  for context_example in examples:
    builder.add_example(context_example)
  for omitted_rule in rules:
    builder.add_omitted_rule(omitted_rule)
  return cl.FrozenExampleSet.from_example_set(builder)


class PopulateTrainSimilarityMetadataTest(parameterized.TestCase):
  """Tests for `similarity_metadata.populate_train_similarity_metadata`."""

  def test_populate_train_similarity_metadata_basic_behavior(self):
    """Train split is untouched; other splits change only in metadata."""

    def single_group_set(context_examples, example):
      # A GroupedExampleSet holding one group with a single example.
      return cl.GroupedExampleSet(example_groups=[
          cl.ExampleGroup(
              context=_create_context(examples=context_examples),
              examples=[example]),
      ])

    validation_set = single_group_set([_EXAMPLE1, _EXAMPLE2], _EXAMPLE3)
    test_set = single_group_set([_EXAMPLE4, _EXAMPLE5], _EXAMPLE6)
    train_set = single_group_set([_EXAMPLE1, _EXAMPLE7], _EXAMPLE6)

    splits = {
        tfds.Split.TRAIN: train_set,
        tfds.Split.VALIDATION: validation_set,
        tfds.Split.TEST: test_set,
    }
    # Deep snapshot taken before the call so in-place changes can be detected.
    original_splits = copy.deepcopy(splits)

    similarity_metadata.populate_train_similarity_metadata(splits)

    def rendered(split_map, split, include_metadata):
      return split_map[split].to_string(include_metadata=include_metadata)

    with self.subTest('train_set_should_remain_unchanged'):
      self.assertEqual(
          rendered(original_splits, tfds.Split.TRAIN, True),
          rendered(splits, tfds.Split.TRAIN, True))

    # For each non-train split: content (without metadata) must be identical,
    # while the rendering that includes metadata must differ.
    non_train_checks = (
        (tfds.Split.VALIDATION,
         'validation_set_should_remain_unchanged_except_metadata',
         'validation_set_metadata_should_be_modified'),
        (tfds.Split.TEST,
         'test_set_should_remain_unchanged_except_metadata',
         'test_set_metadata_should_be_modified'),
    )
    for split, unchanged_name, modified_name in non_train_checks:
      with self.subTest(unchanged_name):
        self.assertEqual(
            rendered(original_splits, split, False),
            rendered(splits, split, False))
      with self.subTest(modified_name):
        self.assertNotEqual(
            rendered(original_splits, split, True),
            rendered(splits, split, True))

  # pyformat: disable
  @parameterized.named_parameters(
      dict(
          testcase_name='context_rules_half_match_train0_while_reply_matches_train1',
          test_example_group=cl.ExampleGroup(
              context=_create_context(rules=['a', 'g'], examples=[_EXAMPLE8]),
              examples=[cl.Example(request='q1', reply='r1_2')]),
          expected_context_train_similarity_metadata=cl.ExampleSetTrainSimilarityMetadata(
              # One of the two context rules ('a') is also a rule of train0.
              nearest_similarity_by_rule_overlap=0.5,
              # _EXAMPLE8 appears in no train context.
              nearest_similarity_by_example_overlap=0.0)),
      dict(
          testcase_name='context_examples_fully_match_train1_and_reply_matches_train1',
          test_example_group=cl.ExampleGroup(
              context=_create_context(rules=['g'], examples=[_EXAMPLE4]),
              examples=[cl.Example(request='q1', reply='r1_2')]),
          expected_context_train_similarity_metadata=cl.ExampleSetTrainSimilarityMetadata(
              # Rule 'g' appears in no train context.
              nearest_similarity_by_rule_overlap=0.0,
              # The only context example (_EXAMPLE4) is in train1's context.
              nearest_similarity_by_example_overlap=1.0)))
  # pyformat: enable
  def test_populate_train_similarity_metadata_yields_correct_context_metadata(
      self, test_example_group, expected_context_train_similarity_metadata):
    """Checks context-level train similarity of one test group."""
    test_set = cl.GroupedExampleSet(example_groups=[test_example_group])

    # Every train group asks 'q1'; they differ in context rules/examples.
    train0 = cl.ExampleGroup(
        context=_create_context(
            rules=['a', 'b', 'c'],
            examples=[_EXAMPLE1, _EXAMPLE2, _EXAMPLE3]),
        examples=[cl.Example(request='q1', reply='r1_1')])
    # Reply differs from train0.
    train1 = cl.ExampleGroup(
        context=_create_context(
            rules=['d', 'e'], examples=[_EXAMPLE4, _EXAMPLE5]),
        examples=[cl.Example(request='q1', reply='r1_2')])
    # Same reply as train1 (so also different from train0).
    train2 = cl.ExampleGroup(
        context=_create_context(
            rules=['f'], examples=[_EXAMPLE6, _EXAMPLE7]),
        examples=[cl.Example(request='q1', reply='r1_2')])
    train_set = cl.GroupedExampleSet(example_groups=[train0, train1, train2])

    splits = {
        tfds.Split.TRAIN: train_set,
        tfds.Split.TEST: test_set,
    }

    similarity_metadata.populate_train_similarity_metadata(splits)

    test_group = splits[tfds.Split.TEST].example_groups[0]
    actual = test_group.context.metadata.train_similarity
    self.assertEqual(
        expected_context_train_similarity_metadata,
        actual,
        f'Full similarity metadata: {actual}')

  # pyformat: disable
  @parameterized.named_parameters(
      dict(
          testcase_name='context_examples_match_no_examples_from_train',
          test_example_group=cl.ExampleGroup(
              context=_create_context(examples=[_EXAMPLE8]),
              examples=[cl.Example(request='q1', reply='r1_2')]),
          expected_example_train_similarity_metadata=cl.ExampleTrainSimilarityMetadata(
              # train0, train1 and train2 all ask 'q1'.
              num_train_examples_with_same_request=3,
              # train1 and train2.
              num_train_examples_with_same_request_and_reply=2,
              num_train_examples_with_same_request_and_output=2,
              # r1_1 and r1_2.
              num_unique_train_replies=2,
              # All candidates tie at similarity 0.0, so the first one (train0)
              # is taken as the nearest neighbor.
              nearest_reply_matches=False,
              nearest_qualifier_matches=False,
              # No context example of the test group appears in a 'q1' train
              # context.
              nearest_similarity=0.0,
              # With all scores at 0.0 the consensus equals the nearest value.
              consensus_reply_matches=False,
              consensus_qualifier_matches=False)),
      dict(
          testcase_name='nearest_value_same_as_consensus_value',
          test_example_group=cl.ExampleGroup(
              context=_create_context(examples=[_EXAMPLE4]),
              examples=[cl.Example(request='q1', reply='r1_2')]),
          expected_example_train_similarity_metadata=cl.ExampleTrainSimilarityMetadata(
              # train0, train1 and train2.
              num_train_examples_with_same_request=3,
              # train1 and train2.
              num_train_examples_with_same_request_and_reply=2,
              num_train_examples_with_same_request_and_output=2,
              # r1_1 and r1_2.
              num_unique_train_replies=2,
              # Nearest is train1 with 1/1 overlap.
              nearest_reply_matches=True,
              nearest_qualifier_matches=True,
              nearest_similarity=1.0,
              # Consensus is also driven by train1 (1/1).
              consensus_reply_matches=True,
              consensus_qualifier_matches=True)),
      dict(
          testcase_name='nearest_value_differs_from_consensus_value',
          test_example_group=cl.ExampleGroup(
              context=_create_context(examples=[
                  _EXAMPLE1, _EXAMPLE2, _EXAMPLE3, _EXAMPLE4, _EXAMPLE5,
                  _EXAMPLE6, _EXAMPLE7, _EXAMPLE8]),
              examples=[cl.Example(request='q1', reply='r1_1')]),
          expected_example_train_similarity_metadata=cl.ExampleTrainSimilarityMetadata(
              # train0, train1 and train2.
              num_train_examples_with_same_request=3,
              # Only train0 replies r1_1.
              num_train_examples_with_same_request_and_reply=1,
              # None: train0 matches the reply but carries qualifier D.
              num_train_examples_with_same_request_and_output=0,
              # r1_1 and r1_2.
              num_unique_train_replies=2,
              # Nearest is train0 with 3/8 overlap.
              nearest_reply_matches=True,
              nearest_qualifier_matches=False,
              nearest_similarity=0.375,
              # Consensus reply r1_2 wins: train1 (2/8) + train2 (2/8) > 3/8.
              consensus_reply_matches=False,
              # Consensus qualifier M wins by the same tally.
              consensus_qualifier_matches=True)),
      dict(
          testcase_name='only_examples_with_same_request_considered_for_similarity',
          test_example_group=cl.ExampleGroup(
              context=_create_context(examples=[
                  _EXAMPLE1, _EXAMPLE2, _EXAMPLE3, _EXAMPLE8]),
              examples=[cl.Example(request='q2', reply='r2_1')]),
          expected_example_train_similarity_metadata=cl.ExampleTrainSimilarityMetadata(
              # Only train3 asks 'q2'.
              num_train_examples_with_same_request=1,
              num_train_examples_with_same_request_and_reply=1,
              num_train_examples_with_same_request_and_output=1,
              # r2_1.
              num_unique_train_replies=1,
              # Nearest is train3 (1/4). train0 would score 3/4 but is skipped
              # because its request differs.
              nearest_reply_matches=True,
              nearest_qualifier_matches=True,
              nearest_similarity=0.25,
              # Consensus equals the nearest (train3).
              consensus_reply_matches=True,
              consensus_qualifier_matches=True)),
  )
  # pyformat: enable
  def test_populate_train_similarity_metadata_yields_correct_example_metadata(
      self, test_example_group, expected_example_train_similarity_metadata):
    """Checks example-level train similarity of the first test example."""
    test_set = cl.GroupedExampleSet(example_groups=[test_example_group])

    # Request 'q1' answered with r1_1 and an explicit qualifier D.
    train0 = cl.ExampleGroup(
        context=_create_context(examples=[_EXAMPLE1, _EXAMPLE2, _EXAMPLE3]),
        examples=[
            cl.Example(request='q1', reply='r1_1', qualifier=cl.Qualifier.D)
        ])
    # Same request as train0, different output.
    train1 = cl.ExampleGroup(
        context=_create_context(examples=[_EXAMPLE4, _EXAMPLE5]),
        examples=[cl.Example(request='q1', reply='r1_2')])
    # Same output as train1; lets consensus outvote the nearest neighbor.
    train2 = cl.ExampleGroup(
        context=_create_context(examples=[_EXAMPLE6, _EXAMPLE7]),
        examples=[cl.Example(request='q1', reply='r1_2')])
    # Different request ('q2').
    train3 = cl.ExampleGroup(
        context=_create_context(examples=[_EXAMPLE8]),
        examples=[cl.Example(request='q2', reply='r2_1')])
    train_set = cl.GroupedExampleSet(
        example_groups=[train0, train1, train2, train3])

    splits = {
        tfds.Split.TRAIN: train_set,
        tfds.Split.TEST: test_set,
    }

    similarity_metadata.populate_train_similarity_metadata(splits)

    first_test_group = splits[tfds.Split.TEST].example_groups[0]
    actual = first_test_group[0].metadata.train_similarity
    self.assertEqual(
        expected_example_train_similarity_metadata,
        actual,
        f'Full similarity metadata: {actual}')


if __name__ == '__main__':
  # Run this module's tests through absltest's entry point when it is executed
  # directly as a script.
  absltest.main()
