# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import itertools

from absl.testing import absltest
from absl.testing import parameterized
from language.compgen.nqg.tasks import mcd_utils
import numpy as np

from conceptual_learning.cscan import divergence_maximization
from conceptual_learning.cscan import inputs


class HelperFunctionsTest(parameterized.TestCase):
  """Tests for the private incremental divergence and coverage helpers.

  The divergence tests compare each incremental helper of
  divergence_maximization against a from-scratch call to
  mcd_utils.compute_divergence on the correspondingly updated counters. The
  coverage tests compare the helpers' results against hand-computed fractions.
  """

  # Coefficient passed to every divergence computation in these tests.
  _COEF = 0.1

  def _make_counts(self, d_count_in_2=6):
    """Returns a fresh (counts_1, counts_2) pair of counters.

    The keys could be either atoms or compounds here; the divergence tests
    just call them compounds.

    Args:
      d_count_in_2: Count stored under key 'd' in the second counter.
    """
    counts_1 = collections.Counter(a=1, b=2, c=3)
    counts_2 = collections.Counter(b=4, c=5, d=d_count_in_2)
    return counts_1, counts_2

  def _baseline_stats(self, counts_1, counts_2):
    """Returns (sum_1, sum_2, divergence) for the unmodified counters."""
    total_1 = sum(counts_1.values())
    total_2 = sum(counts_2.values())
    divergence = mcd_utils.compute_divergence(counts_1, counts_2, self._COEF)
    return total_1, total_2, divergence

  @parameterized.named_parameters(
      ('unseen_compound', {'x'}),
      ('compound_only_in_1', {'a'}),
      ('common_compound', {'b'}),
      ('multiple_common_compounds', {'b', 'c'}),
      ('compound_only_in_2', {'d'}),
  )
  def test_compute_divergence_if_added_1(self, compounds_to_add):
    # The incremental result for adding to counter 1 must agree with a full
    # recomputation on the updated counter.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    actual = divergence_maximization._compute_divergence_if_added_1(
        first, second, total_1, total_2, compounds_to_add, baseline,
        self._COEF)

    # Reference value: bump each added compound by one and recompute.
    first_after = first.copy()
    first_after.update(compounds_to_add)
    expected = mcd_utils.compute_divergence(first_after, second, self._COEF)

    self.assertAlmostEqual(actual, expected)

  @parameterized.named_parameters(
      ('unseen_compound', {'x'}),
      ('compound_only_in_1', {'a'}),
      ('common_compound', {'b'}),
      ('multiple_common_compounds', {'b', 'c'}),
      ('compound_only_in_2', {'d'}),
  )
  def test_compute_divergence_if_added_2(self, compounds_to_add):
    # The incremental result for adding to counter 2 must agree with a full
    # recomputation on the updated counter.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    actual = divergence_maximization._compute_divergence_if_added_2(
        first, second, total_1, total_2, compounds_to_add, baseline,
        self._COEF)

    # Reference value: bump each added compound by one and recompute.
    second_after = second.copy()
    second_after.update(compounds_to_add)
    expected = mcd_utils.compute_divergence(first, second_after, self._COEF)

    self.assertAlmostEqual(actual, expected)

  @parameterized.named_parameters(
      ('compound_only_in_1', {'a'}),
      ('common_compound', {'b'}),
      ('multiple_common_compounds', {'b', 'c'}),
  )
  def test_compute_divergence_if_removed_1(self, compounds_to_remove):
    # The incremental result for removing from counter 1 must agree with a
    # full recomputation on the updated counter.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    actual = divergence_maximization._compute_divergence_if_removed_1(
        first, second, total_1, total_2, compounds_to_remove, baseline,
        self._COEF)

    # Reference value: decrement each removed compound by one and recompute.
    first_after = first.copy()
    first_after.subtract(compounds_to_remove)
    expected = mcd_utils.compute_divergence(first_after, second, self._COEF)

    self.assertAlmostEqual(actual, expected)

  @parameterized.named_parameters(
      ('compound_only_in_2', {'d'}),
      ('common_compound', {'b'}),
      ('multiple_common_compounds', {'b', 'c'}),
  )
  def test_compute_divergence_if_removed_2(self, compounds_to_remove):
    # The incremental result for removing from counter 2 must agree with a
    # full recomputation on the updated counter.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    actual = divergence_maximization._compute_divergence_if_removed_2(
        first, second, total_1, total_2, compounds_to_remove, baseline,
        self._COEF)

    # Reference value: decrement each removed compound by one and recompute.
    second_after = second.copy()
    second_after.subtract(compounds_to_remove)
    expected = mcd_utils.compute_divergence(first, second_after, self._COEF)

    self.assertAlmostEqual(actual, expected)

  def test_compute_divergence_if_removed_1_should_raise_if_compound_missing(
      self):
    # Removing a compound that counter 1 does not contain must be rejected.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    absent_from_1 = {'d'}
    expected_message = 'Item to be removed does not have positive count'
    with self.assertRaisesRegex(ValueError, expected_message):
      divergence_maximization._compute_divergence_if_removed_1(
          first, second, total_1, total_2, absent_from_1, baseline,
          self._COEF)

  def test_compute_divergence_if_removed_2_should_raise_if_compound_missing(
      self):
    # Removing a compound that counter 2 does not contain must be rejected.
    first, second = self._make_counts()
    total_1, total_2, baseline = self._baseline_stats(first, second)

    absent_from_2 = {'a'}
    expected_message = 'Item to be removed does not have positive count'
    with self.assertRaisesRegex(ValueError, expected_message):
      divergence_maximization._compute_divergence_if_removed_2(
          first, second, total_1, total_2, absent_from_2, baseline,
          self._COEF)

  @parameterized.named_parameters(
      ('atom_only_in_1', {'a'}, 3 / 4),
      ('common_atom', {'b'}, 3 / 4),
      ('multiple_common_atoms', {'b', 'c'}, 3 / 4),
      ('atom_only_in_2', {'d'}, 4 / 4),
  )
  def test_compute_coverage_if_added_1(self, atoms_to_add, expected_coverage):
    first, second = self._make_counts()
    # Number of distinct atoms appearing in either counter.
    num_distinct_atoms = len(first.keys() | second.keys())

    actual = divergence_maximization._compute_coverage_if_added_1(
        first, num_distinct_atoms, atoms_to_add)

    self.assertEqual(expected_coverage, actual)

  @parameterized.named_parameters(
      ('atom_only_in_1', {'a'}, 3 / 4),
      ('common_atom', {'b'}, 3 / 4),
      ('multiple_common_atoms', {'b', 'c'}, 3 / 4),
      ('atom_only_in_2', {'d'}, 3 / 4),
  )
  def test_compute_coverage_if_added_2(self, atoms_to_add, expected_coverage):
    first, second = self._make_counts()
    # Number of distinct atoms appearing in either counter.
    num_distinct_atoms = len(first.keys() | second.keys())

    actual = divergence_maximization._compute_coverage_if_added_2(
        first, num_distinct_atoms, atoms_to_add)

    self.assertEqual(expected_coverage, actual)

  @parameterized.named_parameters(
      ('atom_only_in_1', {'a'}, 2 / 4),
      ('common_atom', {'b'}, 3 / 4),
      ('multiple_common_atoms', {'b', 'c'}, 3 / 4),
  )
  def test_compute_coverage_if_removed_1(self, atoms_to_remove,
                                         expected_coverage):
    first, second = self._make_counts()
    # Number of distinct atoms appearing in either counter.
    num_distinct_atoms = len(first.keys() | second.keys())

    actual = divergence_maximization._compute_coverage_if_removed_1(
        first, num_distinct_atoms, atoms_to_remove)

    self.assertEqual(expected_coverage, actual)

  @parameterized.named_parameters(
      ('atom_only_in_2', {'d'}, 3 / 4),
      ('common_atom', {'b'}, 3 / 4),
      ('multiple_common_atoms', {'b', 'c'}, 3 / 4),
  )
  def test_compute_coverage_if_removed_2(self, atoms_to_remove,
                                         expected_coverage):
    # Unlike the other tests, 'd' has a count of 1 in the second counter here.
    first, second = self._make_counts(d_count_in_2=1)
    # Number of distinct atoms appearing in either counter.
    num_distinct_atoms = len(first.keys() | second.keys())

    actual = divergence_maximization._compute_coverage_if_removed_2(
        first, num_distinct_atoms, atoms_to_remove)

    self.assertEqual(expected_coverage, actual)


class DivergenceMaximizationTest(parameterized.TestCase):
  """Tests for item sampling, change-type selection and maximize_divergence."""

  def setUp(self):
    super().setUp()
    # Fixed seed so that the sampling-dependent tests are reproducible.
    self.rng = np.random.RandomState(42)

  @parameterized.named_parameters(
      ('no_filter_no_sampling', 3, None, 3),
      ('no_filter_with_sampling', 2, None, 2),
      ('filter_no_sampling', 2, lambda x: x > 'a', 2),
      ('filter_with_sampling', 1, lambda x: x > 'a', 1),
      ('all_filtered', 2, lambda x: False, 2))
  def test_get_sampled_item_by_original_index(self, sample_size, filter_fn,
                                              expected_size):
    # Checks the size of the returned mapping and that every key is the index
    # of its value in the original list.
    # NOTE(review): the 'all_filtered' case still expects sample_size entries,
    # so the helper presumably falls back to unfiltered sampling when the
    # filter rejects everything -- confirm against the implementation.
    items = ['a', 'b', 'c']
    sampled_item_by_original_index = (
        divergence_maximization._get_sampled_item_by_original_index(
            items, sample_size, self.rng, filter_fn))

    self.assertLen(sampled_item_by_original_index, expected_size)

    for original_index, item in sampled_item_by_original_index.items():
      self.assertEqual(item, items[original_index])

  @parameterized.named_parameters(
      ('current_size_1_too_small', 1, 5,
       divergence_maximization._ChangeType.ADD_TO_1),
      ('current_size_2_too_small', 5, 1,
       divergence_maximization._ChangeType.ADD_TO_2),
      ('add_to_1', 4, 6, divergence_maximization._ChangeType.ADD_TO_1),
      ('add_to_2', 6, 4, divergence_maximization._ChangeType.ADD_TO_2))
  def test_should_add_item_if_total_size_not_in_sizes_to_delete(
      self, current_size_1, current_size_2, expected_change_type):
    num_items = 200
    size_1 = 20
    size_2 = 20
    # We choose delete_period and the current sizes so that the current total
    # size is not in sizes_to_delete.
    delete_period = 4
    sizes_to_delete = divergence_maximization._get_sizes_to_delete(
        num_items, delete_period)

    # Snapshot taken before the call so that mutation can be detected below.
    original_sizes_to_delete = sizes_to_delete[:]

    change_type = divergence_maximization._select_change_type(
        current_size_1, current_size_2, size_1, size_2, delete_period,
        sizes_to_delete, self.rng)
    self.assertEqual(change_type, expected_change_type)

    # If _select_change_type returns an "add" change type, then it should not
    # update sizes_to_delete.
    self.assertEqual(original_sizes_to_delete, sizes_to_delete)

  def test_should_remove_item_if_total_size_in_sizes_to_delete(self):
    num_items = 200
    size_1 = 20
    size_2 = 20
    delete_period = 4
    sizes_to_delete = divergence_maximization._get_sizes_to_delete(
        num_items, delete_period)

    current_size_1 = 7
    current_size_2 = 9
    current_total = current_size_1 + current_size_2
    # Either removal direction is acceptable; only the kind of change matters.
    expected_change_types = [
        divergence_maximization._ChangeType.REMOVE_FROM_1,
        divergence_maximization._ChangeType.REMOVE_FROM_2
    ]

    # Precondition: the chosen sizes add up to a scheduled deletion size.
    self.assertIn(current_total, sizes_to_delete)

    change_type = divergence_maximization._select_change_type(
        current_size_1, current_size_2, size_1, size_2, delete_period,
        sizes_to_delete, self.rng)

    self.assertIn(change_type, expected_change_types)

    # If _select_change_type returns a "remove" change type, then it also has
    # the side effect of removing current_total from sizes_to_delete.
    self.assertNotIn(current_total, sizes_to_delete)

  def test_maximize_divergence_should_increase_compound_divergence(self):
    # Compares the compound divergence of the split produced by
    # maximize_divergence against that of a uniform random split with the
    # same sizes.

    # Tuples such as ('a', 'b', 'c', 'd').
    items = list(itertools.combinations('abcdefghijklm', 4))
    size_1 = int(0.9 * len(items))
    size_2 = len(items) - size_1
    options = inputs.CompoundDivergenceOptions(delete_period=5)

    compound_coef = 0.1

    # Atoms are: 'a', 'b', ...
    def get_atoms_fn(item):
      return set(item)

    # Compounds are: ('a', 'b'), ('a', 'c'), ...
    def get_compounds_fn(item):
      return set(itertools.combinations(item, 2))

    # Make a uniform random split to calculate the original divergences.
    # Sample WITHOUT replacement: the default (replace=True) yields duplicate
    # indices, which would leave split 1 with fewer than size_1 items and make
    # the baseline incomparable to the size_1/size_2 split under test. The
    # indices are kept in a set so the membership checks below are O(1).
    indices_1 = set(
        self.rng.choice(len(items), size=size_1, replace=False).tolist())
    uniform_items_1 = [item for i, item in enumerate(items) if i in indices_1]
    uniform_items_2 = [
        item for i, item in enumerate(items) if i not in indices_1
    ]

    # Sanity-check that the baseline split has the intended sizes.
    self.assertLen(uniform_items_1, size_1)
    self.assertLen(uniform_items_2, size_2)

    original_compound_divergence = mcd_utils.measure_example_divergence(
        uniform_items_1, uniform_items_2, get_compounds_fn, coef=compound_coef)

    mcd_items_1, mcd_items_2 = divergence_maximization.maximize_divergence(
        items, size_1, size_2, get_compounds_fn, get_atoms_fn, options,
        self.rng)

    mcd_compound_divergence = mcd_utils.measure_example_divergence(
        mcd_items_1, mcd_items_2, get_compounds_fn, coef=compound_coef)

    with self.subTest('should_increase_compound_divergence'):
      self.assertGreater(mcd_compound_divergence, original_compound_divergence)


# Run the tests in this module with the absl test runner when the file is
# executed as a script.
if __name__ == '__main__':
  absltest.main()
